Skip to content
Snippets Groups Projects
Commit 9ed6b0a7 authored by Nikolai Hartmann's avatar Nikolai Hartmann
Browse files

fixing evaluate_train_test for regression targets

parent 28f14b38
No related branches found
No related tags found
No related merge requests found
......@@ -1119,12 +1119,15 @@ class ClassificationProject(object):
setattr(self, "scores_"+data_name, np.empty(n_events))
for start in range(0, n_events, batch_size):
stop = start+batch_size
getattr(self, "scores_"+data_name)[start:stop] = (
self.predict(
self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])),
mode=mode
).reshape(-1)
outputs = self.predict(
self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])),
mode=mode
)
if not self.target_fields:
scores_batch = outputs.reshape(-1)
else:
scores_batch = outputs[0].reshape(-1)
getattr(self, "scores_"+data_name)[start:stop] = scores_batch
self._dump_to_hdf5("scores_"+data_name)
if do_test:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment