diff --git a/toolkit.py b/toolkit.py index 6da4e4c64beb11b0217e2178537d428f9f3cc72e..f25ccabc1bcf35136d9812deb9cc4e236aa4a220 100755 --- a/toolkit.py +++ b/toolkit.py @@ -852,19 +852,24 @@ class ClassificationProject(object): self.total_epochs += epochs self._write_info("epochs", self.total_epochs) - logger.info("Reloading (and re-transforming) unshuffled training data") - self.load(reload=True) - - logger.info("Create/Update scores for ROC curve") - self.scores_test = self.model.predict(self.x_test) - self.scores_train = self.model.predict(self.x_train) - - self._dump_to_hdf5("scores_train", "scores_test") + self.evaluate_train_test() logger.info("Creating all validation plots") self.plot_all() + def evaluate_train_test(self, do_train=True, do_test=True): + logger.info("Reloading (and re-transforming) unshuffled training data") + self.load(reload=True) + + logger.info("Create/Update scores for train/test sample") + if do_test: + self.scores_test = self.model.predict(self.x_test) + self._dump_to_hdf5("scores_test") + if do_train: + self.scores_train = self.model.predict(self.x_train) + self._dump_to_hdf5("scores_train") + def evaluate(self, x_eval): logger.debug("Evaluate score for {}".format(x_eval))