diff --git a/toolkit.py b/toolkit.py index 8179bdec2fe0619fe11b96db4e9b485c1353b1f1..ff47669b839462d85a813a47dc79771339e6b5be 100755 --- a/toolkit.py +++ b/toolkit.py @@ -79,6 +79,8 @@ def load_from_dir(path): class ClassificationProject(object): + verbose = 1 # verbosity of the fit method + """Simple framework to load data from ROOT TTrees and train Keras neural networks for classification according to some global settings. @@ -904,7 +906,7 @@ class ClassificationProject(object): np.concatenate((batch_0[2], batch_1[2]))) - def train(self, epochs=10): + def train(self, epochs=10, skip_checkpoint=False): self.load() @@ -918,7 +920,8 @@ class ClassificationProject(object): steps_per_epoch=self.steps_per_epoch, epochs=epochs, validation_data=self.validation_data, - callbacks=self.callbacks_list) + callbacks=self.callbacks_list, + verbose=self.verbose) self.is_training = False except KeyboardInterrupt: logger.info("Interrupt training - continue with rest") @@ -932,12 +935,14 @@ class ClassificationProject(object): steps_per_epoch=int(min(label_counts)/self.batch_size), epochs=epochs, validation_data=self.validation_data, - callbacks=self.callbacks_list) + callbacks=self.callbacks_list, + verbose=self.verbose) self.is_training = False except KeyboardInterrupt: logger.info("Interrupt training - continue with rest") - self.checkpoint_model() + if not skip_checkpoint: + self.checkpoint_model() def checkpoint_model(self):