Skip to content
Snippets Groups Projects
Commit a7a38241 authored by Nikolai's avatar Nikolai
Browse files

Query if model should be retrained

parent 4f5d19bd
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
from sys import version_info
if version_info[0] > 2:
raw_input = input
import os
import json
import pickle
......@@ -41,6 +46,7 @@ K.set_session(session)
import ROOT
class ClassificationProject(object):
"""Simple framework to load data from ROOT TTrees and train Keras
......@@ -434,6 +440,19 @@ class ClassificationProject(object):
json.dump(info, of)
@staticmethod
def query_yn(text):
result = None
while result is None:
input_text = raw_input(text)
if len(input_text) > 0:
if input_text.upper()[0] == "Y":
result = True
elif input_text.upper()[0] == "N":
result = False
return result
@property
def model(self):
"Simple MLP"
......@@ -461,10 +480,14 @@ class ClassificationProject(object):
loss='binary_crossentropy',
metrics=['accuracy'])
np.random.set_state(rn_state)
try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
except IOError:
if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ")
if continue_training:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
else:
logger.info("Starting completely new model")
else:
logger.info("No weights found, starting completely new model")
# dump to json for documentation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment