Skip to content
Snippets Groups Projects
Commit 14a7743a authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

wrote function to transform input back into flat for RNN

parent 9ee7e327
No related branches found
No related tags found
No related merge requests found
......@@ -1721,6 +1721,22 @@ class ClassificationProjectRNN(ClassificationProject):
return x_input
def get_input_flat(self, x):
"Transform input back to flat ntuple"
nevent = x[0].shape[0]
x_flat = np.empty((nevent, len(self.fields)), dtype=np.float)
# recurrent fields
for rec_ar, idx in zip(x, self.recurrent_field_idx):
idx = idx.reshape(-1)
for source_idx, target_idx in enumerate(idx):
x_flat[:,target_idx] = rec_ar.reshape(nevent, -1)[:,source_idx]
# flat fields
for source_idx, field_name in enumerate(self.flat_fields):
target_idx = self.fields.index(field_name)
x_flat[:,target_idx] = x[-1][:,source_idx]
return x_flat
def yield_batch(self):
x_train, y_train, w_train = self.training_data
while True:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment