diff --git a/toolkit.py b/toolkit.py index 95fdcbd2a8cad811aae9e3fcbee91ec4003ee79c..5b33e0052d145a22a6210d9c4e831f29c260de61 100755 --- a/toolkit.py +++ b/toolkit.py @@ -19,6 +19,7 @@ import math import glob import shutil import gc +import random import logging logger = logging.getLogger("KerasROOTClassification") @@ -1726,6 +1727,21 @@ class ClassificationProjectRNN(ClassificationProject): evt[line_idx] = self.mask_value + def mask_uniform(self, x): + """ + Mask recurrent fields with a random (uniform) number of objects. Works in place. + """ + for recurrent_field_idx in self.recurrent_field_idx: + for evt in x: + masked = False + nobj = int(random.random()*(recurrent_field_idx.shape[1]+1)) + for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])): + if obj_number == nobj: + masked=True + if masked: + evt[line_idx] = self.mask_value + + def get_input_list(self, x): "Format the input starting from flat ntuple" x_input = []