From 9ed6b0a79d616626d9e99748e64e55dd1248654f Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Tue, 20 Nov 2018 09:01:00 +0100
Subject: [PATCH] fixing evaluate_train_test for regression targets

---
 toolkit.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index b9edda9..fb51f65 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1119,12 +1119,15 @@ class ClassificationProject(object):
             setattr(self, "scores_"+data_name, np.empty(n_events))
             for start in range(0, n_events, batch_size):
                 stop = start+batch_size
-                getattr(self, "scores_"+data_name)[start:stop] = (
-                    self.predict(
-                        self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])),
-                        mode=mode
-                    ).reshape(-1)
+                outputs =  self.predict(
+                    self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])),
+                    mode=mode
                 )
+                if not self.target_fields:
+                    scores_batch = outputs.reshape(-1)
+                else:
+                    scores_batch = outputs[0].reshape(-1)
+                getattr(self, "scores_"+data_name)[start:stop] = scores_batch
             self._dump_to_hdf5("scores_"+data_name)
 
         if do_test:
-- 
GitLab