From 14a7743ad8818311b1251d6b0b6193cab2d3c08e Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 21 Aug 2018 16:04:31 +0200 Subject: [PATCH] wrote function to transform input back into flat for RNN --- toolkit.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/toolkit.py b/toolkit.py index 834d977..c3040d8 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1721,6 +1721,22 @@ class ClassificationProjectRNN(ClassificationProject): return x_input + def get_input_flat(self, x): + "Transform input back to flat ntuple" + nevent = x[0].shape[0] + x_flat = np.empty((nevent, len(self.fields)), dtype=np.float) + # recurrent fields + for rec_ar, idx in zip(x, self.recurrent_field_idx): + idx = idx.reshape(-1) + for source_idx, target_idx in enumerate(idx): + x_flat[:,target_idx] = rec_ar.reshape(nevent, -1)[:,source_idx] + # flat fields + for source_idx, field_name in enumerate(self.flat_fields): + target_idx = self.fields.index(field_name) + x_flat[:,target_idx] = x[-1][:,source_idx] + return x_flat + + def yield_batch(self): x_train, y_train, w_train = self.training_data while True: -- GitLab