diff --git a/toolkit.py b/toolkit.py
index 6da4e4c64beb11b0217e2178537d428f9f3cc72e..f25ccabc1bcf35136d9812deb9cc4e236aa4a220 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -852,19 +852,24 @@ class ClassificationProject(object):
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
 
-        logger.info("Reloading (and re-transforming) unshuffled training data")
-        self.load(reload=True)
-
-        logger.info("Create/Update scores for ROC curve")
-        self.scores_test = self.model.predict(self.x_test)
-        self.scores_train = self.model.predict(self.x_train)
-
-        self._dump_to_hdf5("scores_train", "scores_test")
+        self.evaluate_train_test()
 
         logger.info("Creating all validation plots")
         self.plot_all()
 
 
+    def evaluate_train_test(self, do_train=True, do_test=True):
+        logger.info("Reloading (and re-transforming) unshuffled training data")
+        self.load(reload=True)
+
+        logger.info("Create/Update scores for train/test sample")
+        if do_test:
+            self.scores_test = self.model.predict(self.x_test)
+            self._dump_to_hdf5("scores_test")
+        if do_train:
+            self.scores_train = self.model.predict(self.x_train)
+            self._dump_to_hdf5("scores_train")
+
 
     def evaluate(self, x_eval):
         logger.debug("Evaluate score for {}".format(x_eval))