From f8a5e4436f70db4ac14734007afe8ed30319b44a Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 13 Aug 2018 13:41:45 +0200
Subject: [PATCH] put score evaluation in separate function

---
 toolkit.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 6da4e4c..f25ccab 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -852,19 +852,24 @@ class ClassificationProject(object):
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
 
-        logger.info("Reloading (and re-transforming) unshuffled training data")
-        self.load(reload=True)
-
-        logger.info("Create/Update scores for ROC curve")
-        self.scores_test = self.model.predict(self.x_test)
-        self.scores_train = self.model.predict(self.x_train)
-
-        self._dump_to_hdf5("scores_train", "scores_test")
+        self.evaluate_train_test()
 
         logger.info("Creating all validation plots")
         self.plot_all()
 
 
+    def evaluate_train_test(self, do_train=True, do_test=True):
+        logger.info("Reloading (and re-transforming) unshuffled training data")
+        self.load(reload=True)
+
+        logger.info("Create/Update scores for train/test sample")
+        if do_test:
+            self.scores_test = self.model.predict(self.x_test)
+            self._dump_to_hdf5("scores_test")
+        if do_train:
+            self.scores_train = self.model.predict(self.x_train)
+            self._dump_to_hdf5("scores_train")
+
 
     def evaluate(self, x_eval):
         logger.debug("Evaluate score for {}".format(x_eval))
-- 
GitLab