From 14a7743ad8818311b1251d6b0b6193cab2d3c08e Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 21 Aug 2018 16:04:31 +0200
Subject: [PATCH] wrote function to transform input back into flat for RNN

---
 toolkit.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/toolkit.py b/toolkit.py
index 834d977..c3040d8 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1721,6 +1721,22 @@ class ClassificationProjectRNN(ClassificationProject):
         return x_input
 
 
+    def get_input_flat(self, x):
+        "Transform input back to flat ntuple"
+        nevent = x[0].shape[0]
+        x_flat = np.empty((nevent, len(self.fields)), dtype=np.float)
+        # recurrent fields
+        for rec_ar, idx in zip(x, self.recurrent_field_idx):
+            idx = idx.reshape(-1)
+            for source_idx, target_idx in enumerate(idx):
+                x_flat[:,target_idx] = rec_ar.reshape(nevent, -1)[:,source_idx]
+        # flat fields
+        for source_idx, field_name in enumerate(self.flat_fields):
+            target_idx = self.fields.index(field_name)
+            x_flat[:,target_idx] = x[-1][:,source_idx]
+        return x_flat
+
+
     def yield_batch(self):
         x_train, y_train, w_train = self.training_data
         while True:
-- 
GitLab