From a7a38241e695e12761ab50ba5fc7c53eb5e4d49e Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Wed, 9 May 2018 10:10:40 +0200
Subject: [PATCH] Query if model should be retrained

---
 toolkit.py | 31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index d430d14..3d45c6e 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,5 +1,10 @@
 #!/usr/bin/env python
 
+from sys import version_info
+
+if version_info[0] > 2:
+    raw_input = input
+
 import os
 import json
 import pickle
@@ -41,6 +46,7 @@ K.set_session(session)
 
 import ROOT
 
+
 class ClassificationProject(object):
 
     """Simple framework to load data from ROOT TTrees and train Keras
@@ -434,6 +440,19 @@ class ClassificationProject(object):
             json.dump(info, of)
 
 
+    @staticmethod
+    def query_yn(text):
+        result = None
+        while result is None:
+            input_text = raw_input(text)
+            if len(input_text) > 0:
+                if input_text.upper()[0] == "Y":
+                    result = True
+                elif input_text.upper()[0] == "N":
+                    result = False
+        return result
+
+
     @property
     def model(self):
         "Simple MLP"
@@ -461,10 +480,14 @@ class ClassificationProject(object):
                                 loss='binary_crossentropy',
                                 metrics=['accuracy'])
             np.random.set_state(rn_state)
-            try:
-                self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
-                logger.info("Found and loaded previously trained weights")
-            except IOError:
+            if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
+                continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ")
+                if continue_training:
+                    self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
+                    logger.info("Found and loaded previously trained weights")
+                else:
+                    logger.info("Starting completely new model")
+            else:
                 logger.info("No weights found, starting completely new model")
 
             # dump to json for documentation
-- 
GitLab