From b99f8435e7ae9a8f03450f4e4f8737014697cfb4 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 30 Apr 2018 16:12:24 +0200 Subject: [PATCH] load model weights directly after initialising model, otherwise weights are reinitialised --- toolkit.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/toolkit.py b/toolkit.py index ccd68e7..9904ec7 100755 --- a/toolkit.py +++ b/toolkit.py @@ -332,6 +332,12 @@ class KerasROOTClassification(object): loss='binary_crossentropy', metrics=['accuracy']) + try: + self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) + logger.info("Found and loaded previously trained weights") + except IOError: + logger.info("No weights found, starting completely new model") + # dump to json for documentation with open(os.path.join(self.project_dir, "model.json"), "w") as of: of.write(self._model.to_json()) @@ -376,13 +382,6 @@ class KerasROOTClassification(object): for branch_index, branch in enumerate(self.branches): self.plot_input(branch_index) - try: - self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) - logger.info("Weights found and loaded") - logger.info("Continue training") - except IOError: - logger.info("No weights found, starting completely new training") - self.total_epochs = self._read_info("epochs", 0) logger.info("Train model") @@ -580,8 +579,8 @@ if __name__ == "__main__": identifiers = ["DatasetNumber", "EventNumber"], step_bkg = 100) - #c.load() - c.train(epochs=20) + c.load() + #c.train(epochs=20) c.plot_ROC() - c.plot_loss() - c.plot_accuracy() + # c.plot_loss() + # c.plot_accuracy() -- GitLab