From afa8126395d001a0b916296edce316faa3d62c74 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Wed, 12 Sep 2018 14:41:22 +0200
Subject: [PATCH] option to skip checkpoint and verbosity for fit method

---
 toolkit.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 8179bde..ff47669 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -79,6 +79,8 @@ def load_from_dir(path):
 
 class ClassificationProject(object):
 
+    verbose = 1 # verbosity of the fit method
+
     """Simple framework to load data from ROOT TTrees and train Keras
     neural networks for classification according to some global settings.
 
@@ -904,7 +906,7 @@ class ClassificationProject(object):
                    np.concatenate((batch_0[2], batch_1[2])))
 
 
-    def train(self, epochs=10):
+    def train(self, epochs=10, skip_checkpoint=False):
 
         self.load()
 
@@ -918,7 +920,8 @@ class ClassificationProject(object):
                                          steps_per_epoch=self.steps_per_epoch,
                                          epochs=epochs,
                                          validation_data=self.validation_data,
-                                         callbacks=self.callbacks_list)
+                                         callbacks=self.callbacks_list,
+                                         verbose=self.verbose)
                 self.is_training = False
             except KeyboardInterrupt:
                 logger.info("Interrupt training - continue with rest")
@@ -932,12 +935,14 @@ class ClassificationProject(object):
                                          steps_per_epoch=int(min(label_counts)/self.batch_size),
                                          epochs=epochs,
                                          validation_data=self.validation_data,
-                                         callbacks=self.callbacks_list)
+                                         callbacks=self.callbacks_list,
+                                         verbose=self.verbose)
                 self.is_training = False
             except KeyboardInterrupt:
                 logger.info("Interrupt training - continue with rest")
 
-        self.checkpoint_model()
+        if not skip_checkpoint:
+            self.checkpoint_model()
 
 
     def checkpoint_model(self):
-- 
GitLab