Skip to content
Snippets Groups Projects
Commit a7a38241 authored by Nikolai's avatar Nikolai
Browse files

Query if model should be retrained

parent 4f5d19bd
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python #!/usr/bin/env python
from sys import version_info
if version_info[0] > 2:
raw_input = input
import os import os
import json import json
import pickle import pickle
...@@ -41,6 +46,7 @@ K.set_session(session) ...@@ -41,6 +46,7 @@ K.set_session(session)
import ROOT import ROOT
class ClassificationProject(object): class ClassificationProject(object):
"""Simple framework to load data from ROOT TTrees and train Keras """Simple framework to load data from ROOT TTrees and train Keras
...@@ -434,6 +440,19 @@ class ClassificationProject(object): ...@@ -434,6 +440,19 @@ class ClassificationProject(object):
json.dump(info, of) json.dump(info, of)
@staticmethod
def query_yn(text):
result = None
while result is None:
input_text = raw_input(text)
if len(input_text) > 0:
if input_text.upper()[0] == "Y":
result = True
elif input_text.upper()[0] == "N":
result = False
return result
@property @property
def model(self): def model(self):
"Simple MLP" "Simple MLP"
...@@ -461,10 +480,14 @@ class ClassificationProject(object): ...@@ -461,10 +480,14 @@ class ClassificationProject(object):
loss='binary_crossentropy', loss='binary_crossentropy',
metrics=['accuracy']) metrics=['accuracy'])
np.random.set_state(rn_state) np.random.set_state(rn_state)
try: if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ")
logger.info("Found and loaded previously trained weights") if continue_training:
except IOError: self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
else:
logger.info("Starting completely new model")
else:
logger.info("No weights found, starting completely new model") logger.info("No weights found, starting completely new model")
# dump to json for documentation # dump to json for documentation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment