diff --git a/toolkit.py b/toolkit.py index d430d1404f60ad7d1a7008a0f98738a9a3e42129..3d45c6e591749ecbfc5a619bf933d635d558f921 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +from sys import version_info + +if version_info[0] > 2: + raw_input = input + import os import json import pickle @@ -41,6 +46,7 @@ K.set_session(session) import ROOT + class ClassificationProject(object): """Simple framework to load data from ROOT TTrees and train Keras @@ -434,6 +440,19 @@ class ClassificationProject(object): json.dump(info, of) + @staticmethod + def query_yn(text): + result = None + while result is None: + input_text = raw_input(text) + if len(input_text) > 0: + if input_text.upper()[0] == "Y": + result = True + elif input_text.upper()[0] == "N": + result = False + return result + + @property def model(self): "Simple MLP" @@ -461,10 +480,14 @@ class ClassificationProject(object): loss='binary_crossentropy', metrics=['accuracy']) np.random.set_state(rn_state) - try: - self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) - logger.info("Found and loaded previously trained weights") - except IOError: + if os.path.exists(os.path.join(self.project_dir, "weights.h5")): + continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ") + if continue_training: + self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) + logger.info("Found and loaded previously trained weights") + else: + logger.info("Starting completely new model") + else: logger.info("No weights found, starting completely new model") # dump to json for documentation