From ae3f01480d579f1892f1cd2036fbe8d118845d5a Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Wed, 22 Aug 2018 14:18:55 +0200 Subject: [PATCH] adding mask_uniform function --- toolkit.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/toolkit.py b/toolkit.py index 95fdcbd..5b33e00 100755 --- a/toolkit.py +++ b/toolkit.py @@ -19,6 +19,7 @@ import math import glob import shutil import gc +import random import logging logger = logging.getLogger("KerasROOTClassification") @@ -1726,6 +1727,21 @@ class ClassificationProjectRNN(ClassificationProject): evt[line_idx] = self.mask_value + def mask_uniform(self, x): + """ + Mask recurrent fields with a random (uniform) number of objects. Works in place. + """ + for recurrent_field_idx in self.recurrent_field_idx: + for evt in x: + masked = False + nobj = int(random.random()*(recurrent_field_idx.shape[1]+1)) + for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])): + if obj_number == nobj: + masked=True + if masked: + evt[line_idx] = self.mask_value + + def get_input_list(self, x): "Format the input starting from flat ntuple" x_input = [] -- GitLab