diff --git a/toolkit.py b/toolkit.py
index 834d9774411238f896aa8bb6ffa63fad9a42e3ce..c3040d818f578a21d628df15d2f33bfdf545a9db 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1721,6 +1721,22 @@ class ClassificationProjectRNN(ClassificationProject):
         return x_input
 
 
+    def get_input_flat(self, x):
+        "Transform input back to flat ntuple"
+        nevent = x[0].shape[0]
+        x_flat = np.empty((nevent, len(self.fields)), dtype=np.float)
+        # recurrent fields
+        for rec_ar, idx in zip(x, self.recurrent_field_idx):
+            idx = idx.reshape(-1)
+            for source_idx, target_idx in enumerate(idx):
+                x_flat[:,target_idx] = rec_ar.reshape(nevent, -1)[:,source_idx]
+        # flat fields
+        for source_idx, field_name in enumerate(self.flat_fields):
+            target_idx = self.fields.index(field_name)
+            x_flat[:,target_idx] = x[-1][:,source_idx]
+        return x_flat
+
+
     def yield_batch(self):
         x_train, y_train, w_train = self.training_data
         while True: