From ae3f01480d579f1892f1cd2036fbe8d118845d5a Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Wed, 22 Aug 2018 14:18:55 +0200
Subject: [PATCH] adding mask_uniform function

---
 toolkit.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/toolkit.py b/toolkit.py
index 95fdcbd..5b33e00 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -19,6 +19,7 @@ import math
 import glob
 import shutil
 import gc
+import random
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -1726,6 +1727,21 @@ class ClassificationProjectRNN(ClassificationProject):
                         evt[line_idx] = self.mask_value
 
 
+    def mask_uniform(self, x):
+        """
+        Mask recurrent fields with a random (uniform) number of objects. Works in place.
+        """
+        for recurrent_field_idx in self.recurrent_field_idx:
+            for evt in x:
+                masked = False
+                nobj = int(random.random()*(recurrent_field_idx.shape[1]+1))
+                for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])):
+                    if obj_number == nobj:
+                        masked=True
+                    if masked:
+                        evt[line_idx] = self.mask_value
+
+
     def get_input_list(self, x):
         "Format the input starting from flat ntuple"
         x_input = []
-- 
GitLab