Skip to content
Snippets Groups Projects
Commit b99f8435 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

load model weights directly after initialising model, otherwise weights are reinitialised

parent 1c7e2228
No related branches found
No related tags found
No related merge requests found
...@@ -332,6 +332,12 @@ class KerasROOTClassification(object): ...@@ -332,6 +332,12 @@ class KerasROOTClassification(object):
loss='binary_crossentropy', loss='binary_crossentropy',
metrics=['accuracy']) metrics=['accuracy'])
try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
except IOError:
logger.info("No weights found, starting completely new model")
# dump to json for documentation # dump to json for documentation
with open(os.path.join(self.project_dir, "model.json"), "w") as of: with open(os.path.join(self.project_dir, "model.json"), "w") as of:
of.write(self._model.to_json()) of.write(self._model.to_json())
...@@ -376,13 +382,6 @@ class KerasROOTClassification(object): ...@@ -376,13 +382,6 @@ class KerasROOTClassification(object):
for branch_index, branch in enumerate(self.branches): for branch_index, branch in enumerate(self.branches):
self.plot_input(branch_index) self.plot_input(branch_index)
try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Weights found and loaded")
logger.info("Continue training")
except IOError:
logger.info("No weights found, starting completely new training")
self.total_epochs = self._read_info("epochs", 0) self.total_epochs = self._read_info("epochs", 0)
logger.info("Train model") logger.info("Train model")
...@@ -580,8 +579,8 @@ if __name__ == "__main__": ...@@ -580,8 +579,8 @@ if __name__ == "__main__":
identifiers = ["DatasetNumber", "EventNumber"], identifiers = ["DatasetNumber", "EventNumber"],
step_bkg = 100) step_bkg = 100)
#c.load() c.load()
c.train(epochs=20) #c.train(epochs=20)
c.plot_ROC() c.plot_ROC()
c.plot_loss() # c.plot_loss()
c.plot_accuracy() # c.plot_accuracy()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment