Skip to content
Snippets Groups Projects
Commit ae3f0148 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

adding mask_uniform function

parent 5119a79d
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,7 @@ import math ...@@ -19,6 +19,7 @@ import math
import glob import glob
import shutil import shutil
import gc import gc
import random
import logging import logging
logger = logging.getLogger("KerasROOTClassification") logger = logging.getLogger("KerasROOTClassification")
...@@ -1726,6 +1727,21 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1726,6 +1727,21 @@ class ClassificationProjectRNN(ClassificationProject):
evt[line_idx] = self.mask_value evt[line_idx] = self.mask_value
def mask_uniform(self, x):
"""
Mask recurrent fields with a random (uniform) number of objects. Works in place.
"""
for recurrent_field_idx in self.recurrent_field_idx:
for evt in x:
masked = False
nobj = int(random.random()*(recurrent_field_idx.shape[1]+1))
for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])):
if obj_number == nobj:
masked=True
if masked:
evt[line_idx] = self.mask_value
def get_input_list(self, x): def get_input_list(self, x):
"Format the input starting from flat ntuple" "Format the input starting from flat ntuple"
x_input = [] x_input = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment