diff --git a/toolkit.py b/toolkit.py index ccd68e7717ab5514c37bf4027495ca64373a92dd..9904ec7be590293503af6a515d64c599df07c115 100755 --- a/toolkit.py +++ b/toolkit.py @@ -332,6 +332,12 @@ class KerasROOTClassification(object): loss='binary_crossentropy', metrics=['accuracy']) + try: + self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) + logger.info("Found and loaded previously trained weights") + except IOError: + logger.info("No weights found, starting completely new model") + # dump to json for documentation with open(os.path.join(self.project_dir, "model.json"), "w") as of: of.write(self._model.to_json()) @@ -376,13 +382,6 @@ class KerasROOTClassification(object): for branch_index, branch in enumerate(self.branches): self.plot_input(branch_index) - try: - self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) - logger.info("Weights found and loaded") - logger.info("Continue training") - except IOError: - logger.info("No weights found, starting completely new training") - self.total_epochs = self._read_info("epochs", 0) logger.info("Train model") @@ -580,8 +579,8 @@ if __name__ == "__main__": identifiers = ["DatasetNumber", "EventNumber"], step_bkg = 100) - #c.load() - c.train(epochs=20) + c.load() + #c.train(epochs=20) c.plot_ROC() - c.plot_loss() - c.plot_accuracy() + # c.plot_loss() + # c.plot_accuracy()