Skip to content
Snippets Groups Projects
Commit ae3f0148 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

adding mask_uniform function

parent 5119a79d
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ import math
import glob
import shutil
import gc
import random
import logging
logger = logging.getLogger("KerasROOTClassification")
......@@ -1726,6 +1727,21 @@ class ClassificationProjectRNN(ClassificationProject):
evt[line_idx] = self.mask_value
def mask_uniform(self, x):
"""
Mask recurrent fields with a random (uniform) number of objects. Works in place.
"""
for recurrent_field_idx in self.recurrent_field_idx:
for evt in x:
masked = False
nobj = int(random.random()*(recurrent_field_idx.shape[1]+1))
for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])):
if obj_number == nobj:
masked=True
if masked:
evt[line_idx] = self.mask_value
def get_input_list(self, x):
"Format the input starting from flat ntuple"
x_input = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment