From f8a5e4436f70db4ac14734007afe8ed30319b44a Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 13 Aug 2018 13:41:45 +0200 Subject: [PATCH] put score evaluation in separate function --- toolkit.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/toolkit.py b/toolkit.py index 6da4e4c..f25ccab 100755 --- a/toolkit.py +++ b/toolkit.py @@ -852,19 +852,24 @@ class ClassificationProject(object): self.total_epochs += epochs self._write_info("epochs", self.total_epochs) - logger.info("Reloading (and re-transforming) unshuffled training data") - self.load(reload=True) - - logger.info("Create/Update scores for ROC curve") - self.scores_test = self.model.predict(self.x_test) - self.scores_train = self.model.predict(self.x_train) - - self._dump_to_hdf5("scores_train", "scores_test") + self.evaluate_train_test() logger.info("Creating all validation plots") self.plot_all() + def evaluate_train_test(self, do_train=True, do_test=True): + logger.info("Reloading (and re-transforming) unshuffled training data") + self.load(reload=True) + + logger.info("Create/Update scores for train/test sample") + if do_test: + self.scores_test = self.model.predict(self.x_test) + self._dump_to_hdf5("scores_test") + if do_train: + self.scores_train = self.model.predict(self.x_train) + self._dump_to_hdf5("scores_train") + def evaluate(self, x_eval): logger.debug("Evaluate score for {}".format(x_eval)) -- GitLab