From b99f8435e7ae9a8f03450f4e4f8737014697cfb4 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 30 Apr 2018 16:12:24 +0200
Subject: [PATCH] load model weights directly after initialising model,
 otherwise weights are reinitialised

---
 toolkit.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index ccd68e7..9904ec7 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -332,6 +332,12 @@ class KerasROOTClassification(object):
                                 loss='binary_crossentropy',
                                 metrics=['accuracy'])
 
+            try:
+                self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
+                logger.info("Found and loaded previously trained weights")
+            except IOError:
+                logger.info("No weights found, starting completely new model")
+
             # dump to json for documentation
             with open(os.path.join(self.project_dir, "model.json"), "w") as of:
                 of.write(self._model.to_json())
@@ -376,13 +382,6 @@ class KerasROOTClassification(object):
         for branch_index, branch in enumerate(self.branches):
             self.plot_input(branch_index)
 
-        try:
-            self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
-            logger.info("Weights found and loaded")
-            logger.info("Continue training")
-        except IOError:
-            logger.info("No weights found, starting completely new training")
-
         self.total_epochs = self._read_info("epochs", 0)
 
         logger.info("Train model")
@@ -580,8 +579,8 @@ if __name__ == "__main__":
                                 identifiers = ["DatasetNumber", "EventNumber"],
                                 step_bkg = 100)
 
-    #c.load()
-    c.train(epochs=20)
+    c.load()
+    #c.train(epochs=20)
     c.plot_ROC()
-    c.plot_loss()
-    c.plot_accuracy()
+    # c.plot_loss()
+    # c.plot_accuracy()
-- 
GitLab