From dca187ae2b5e49ed2bb498955c5f3871d3c6acf5 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Thu, 1 Nov 2018 14:40:54 +0100
Subject: [PATCH] fixing evaluate_train_test (includes transformation now)

---
 toolkit.py | 51 ++++++++++++++++++++++++++-------------------------
 1 file changed, 26 insertions(+), 25 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 5a6aa8a..524a66f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -990,6 +990,32 @@ class ClassificationProject(object):
 
         self._dump_to_hdf5("scores_train")
 
+    def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
+        "Calculate scores for training and test sample"
+
+        if mode is not None:
+            self._write_info("scores_mode", mode)
+
+        def eval_score(data_name):
+            logger.info("Create/Update scores for {} sample".format(data_name))
+            n_events = len(getattr(self, "x_"+data_name))
+            setattr(self, "scores_"+data_name, np.empty(n_events))
+            for start in range(0, n_events, batch_size):
+                stop = start+batch_size
+                getattr(self, "scores_"+data_name)[start:stop] = (
+                    self.predict(
+                        self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])),
+                        mode=mode
+                    ).reshape(-1)
+                )
+            self._dump_to_hdf5("scores_"+data_name)
+
+        if do_test:
+            eval_score("test")
+        if do_train:
+            eval_score("train")
+
+
     def predict(self, x, mode=None):
         """
         Calculate the scores for a (transformed) array of input values.
@@ -1845,31 +1871,6 @@ class ClassificationProjectRNN(ClassificationProject):
 
         return x_flat
 
-    def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
-
-        if mode is not None:
-            self._write_info("scores_mode", mode)
-
-        def eval_score(data_name):
-            logger.info("Create/Update scores for {} sample".format(data_name))
-            n_events = len(getattr(self, "x_"+data_name))
-            setattr(self, "scores_"+data_name, np.empty(n_events))
-            for start in range(0, n_events, batch_size):
-                stop = start+batch_size
-                getattr(self, "scores_"+data_name)[start:stop] = (
-                    self.predict(
-                        self.get_input_list(getattr(self, "x_"+data_name)[start:stop]),
-                        mode=mode
-                    ).reshape(-1)
-                )
-            self._dump_to_hdf5("scores_"+data_name)
-
-        if do_test:
-            eval_score("test")
-        if do_train:
-            eval_score("train")
-
-
     def evaluate(self, x_eval, mode=None):
         logger.debug("Evaluate score for {}".format(x_eval))
         x_eval = self.transform(x_eval)
-- 
GitLab
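
For context, the patch moves evaluate_train_test from ClassificationProjectRNN into the ClassificationProject base class and, crucially, now applies self.transform(...) to each batch before self.get_input_list(...), so scores are computed on transformed inputs rather than raw ones. Below is a minimal, self-contained sketch of that batched-scoring pattern. DemoProject, its transform (simple standardization), and predict (a dummy scorer) are hypothetical stand-ins for the toolkit's actual classes, not the project's real API.

import numpy as np


class DemoProject:
    """Hypothetical stand-in illustrating the pattern from the patch:
    transform each batch, then predict on it."""

    def __init__(self, x_train, x_test):
        self.x_train = x_train
        self.x_test = x_test
        # assumed: transformation parameters fitted on the training sample
        self.mean = x_train.mean(axis=0)
        self.std = x_train.std(axis=0)

    def transform(self, x):
        # stand-in for the toolkit's input transformation (standardization)
        return (x - self.mean) / self.std

    def predict(self, x, mode=None):
        # dummy scorer in place of the Keras model's predict
        return x.sum(axis=1, keepdims=True)

    def evaluate_train_test(self, do_train=True, do_test=True,
                            batch_size=10000, mode=None):
        "Calculate scores for training and test sample (sketch)"

        def eval_score(data_name):
            x = getattr(self, "x_" + data_name)
            scores = np.empty(len(x))
            for start in range(0, len(x), batch_size):
                stop = start + batch_size
                # the fix: transform the batch *before* predicting on it
                scores[start:stop] = self.predict(
                    self.transform(x[start:stop]), mode=mode
                ).reshape(-1)
            setattr(self, "scores_" + data_name, scores)

        if do_test:
            eval_score("test")
        if do_train:
            eval_score("train")


rng = np.random.default_rng(42)
project = DemoProject(rng.normal(size=(25000, 4)), rng.normal(size=(5000, 4)))
project.evaluate_train_test(batch_size=10000)
print(project.scores_train.shape, project.scores_test.shape)  # (25000,) (5000,)

Scoring in fixed-size batches keeps memory bounded for large samples, and the slice assignment is safe even when stop runs past the end of the array because NumPy clamps out-of-range slices.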