From a7a38241e695e12761ab50ba5fc7c53eb5e4d49e Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Wed, 9 May 2018 10:10:40 +0200 Subject: [PATCH] Query if model should be retrained --- toolkit.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/toolkit.py b/toolkit.py index d430d14..3d45c6e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +from sys import version_info + +if version_info[0] > 2: + raw_input = input + import os import json import pickle @@ -41,6 +46,7 @@ K.set_session(session) import ROOT + class ClassificationProject(object): """Simple framework to load data from ROOT TTrees and train Keras @@ -434,6 +440,19 @@ class ClassificationProject(object): json.dump(info, of) + @staticmethod + def query_yn(text): + result = None + while result is None: + input_text = raw_input(text) + if len(input_text) > 0: + if input_text.upper()[0] == "Y": + result = True + elif input_text.upper()[0] == "N": + result = False + return result + + @property def model(self): "Simple MLP" @@ -461,10 +480,14 @@ class ClassificationProject(object): loss='binary_crossentropy', metrics=['accuracy']) np.random.set_state(rn_state) - try: - self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) - logger.info("Found and loaded previously trained weights") - except IOError: + if os.path.exists(os.path.join(self.project_dir, "weights.h5")): + continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ") + if continue_training: + self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) + logger.info("Found and loaded previously trained weights") + else: + logger.info("Starting completely new model") + else: logger.info("No weights found, starting completely new model") # dump to json for documentation -- GitLab