From 9ed6b0a79d616626d9e99748e64e55dd1248654f Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <nikolai.hartmann@gmx.de> Date: Tue, 20 Nov 2018 09:01:00 +0100 Subject: [PATCH] fixing evaluate_train_test for regression targets --- toolkit.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/toolkit.py b/toolkit.py index b9edda9..fb51f65 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1119,12 +1119,15 @@ class ClassificationProject(object): setattr(self, "scores_"+data_name, np.empty(n_events)) for start in range(0, n_events, batch_size): stop = start+batch_size - getattr(self, "scores_"+data_name)[start:stop] = ( - self.predict( - self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])), - mode=mode - ).reshape(-1) + outputs = self.predict( + self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])), + mode=mode ) + if not self.target_fields: + scores_batch = outputs.reshape(-1) + else: + scores_batch = outputs[0].reshape(-1) + getattr(self, "scores_"+data_name)[start:stop] = scores_batch self._dump_to_hdf5("scores_"+data_name) if do_test: -- GitLab