diff --git a/toolkit.py b/toolkit.py index 5a6aa8a59e1866dd33d0f453ce06777b0b26cfee..524a66f05dbb4fc53e5ab74271855a6f6306c974 100755 --- a/toolkit.py +++ b/toolkit.py @@ -990,6 +990,32 @@ class ClassificationProject(object): self._dump_to_hdf5("scores_train") + def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None): + "Calculate scores for training and test sample" + + if mode is not None: + self._write_info("scores_mode", mode) + + def eval_score(data_name): + logger.info("Create/Update scores for {} sample".format(data_name)) + n_events = len(getattr(self, "x_"+data_name)) + setattr(self, "scores_"+data_name, np.empty(n_events)) + for start in range(0, n_events, batch_size): + stop = start+batch_size + getattr(self, "scores_"+data_name)[start:stop] = ( + self.predict( + self.get_input_list(self.transform(getattr(self, "x_"+data_name)[start:stop])), + mode=mode + ).reshape(-1) + ) + self._dump_to_hdf5("scores_"+data_name) + + if do_test: + eval_score("test") + if do_train: + eval_score("train") + + def predict(self, x, mode=None): """ Calculate the scores for a (transformed) array of input values. @@ -1845,31 +1871,6 @@ class ClassificationProjectRNN(ClassificationProject): return x_flat - def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None): - - if mode is not None: - self._write_info("scores_mode", mode) - - def eval_score(data_name): - logger.info("Create/Update scores for {} sample".format(data_name)) - n_events = len(getattr(self, "x_"+data_name)) - setattr(self, "scores_"+data_name, np.empty(n_events)) - for start in range(0, n_events, batch_size): - stop = start+batch_size - getattr(self, "scores_"+data_name)[start:stop] = ( - self.predict( - self.get_input_list(getattr(self, "x_"+data_name)[start:stop]), - mode=mode - ).reshape(-1) - ) - self._dump_to_hdf5("scores_"+data_name) - - if do_test: - eval_score("test") - if do_train: - eval_score("train") - - def evaluate(self, x_eval, mode=None): logger.debug("Evaluate score for {}".format(x_eval)) x_eval = self.transform(x_eval)