diff --git a/toolkit.py b/toolkit.py index b9edda9e43ed12d74fddbe5ca3f880f7106502ef..fb51f65d7c944fe2ee1c10f015f69012bbe71851 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1119,12 +1119,15 @@ class ClassificationProject(object): setattr(self, "scores_"+data_name, np.empty(n_events)) for start in range(0, n_events, batch_size): stop = start+batch_size - getattr(self, "scores_"+data_name)[start:stop] = ( - self.predict( - self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])), - mode=mode - ).reshape(-1) + outputs = self.predict( + self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])), + mode=mode ) + if not self.target_fields: + scores_batch = outputs.reshape(-1) + else: + scores_batch = outputs[0].reshape(-1) + getattr(self, "scores_"+data_name)[start:stop] = scores_batch self._dump_to_hdf5("scores_"+data_name) if do_test: