diff --git a/toolkit.py b/toolkit.py index 834d9774411238f896aa8bb6ffa63fad9a42e3ce..c3040d818f578a21d628df15d2f33bfdf545a9db 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1721,6 +1721,22 @@ class ClassificationProjectRNN(ClassificationProject): return x_input + def get_input_flat(self, x): + "Transform input back to flat ntuple" + nevent = x[0].shape[0] + x_flat = np.empty((nevent, len(self.fields)), dtype=np.float) + # recurrent fields + for rec_ar, idx in zip(x, self.recurrent_field_idx): + idx = idx.reshape(-1) + for source_idx, target_idx in enumerate(idx): + x_flat[:,target_idx] = rec_ar.reshape(nevent, -1)[:,source_idx] + # flat fields + for source_idx, field_name in enumerate(self.flat_fields): + target_idx = self.fields.index(field_name) + x_flat[:,target_idx] = x[-1][:,source_idx] + return x_flat + + def yield_batch(self): x_train, y_train, w_train = self.training_data while True: