From 4bd7de4a6f04f60f3c36d6a2f93704f6447e8402 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Wed, 16 May 2018 11:49:26 +0200 Subject: [PATCH] use ModelCheckpoint callback --- toolkit.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/toolkit.py b/toolkit.py index 01b17e7..dbec934 100755 --- a/toolkit.py +++ b/toolkit.py @@ -27,7 +27,7 @@ from sklearn.metrics import roc_curve, auc from keras.models import Sequential from keras.layers import Dense, Dropout from keras.models import model_from_json -from keras.callbacks import History, EarlyStopping, CSVLogger +from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint from keras.optimizers import SGD import keras.optimizers @@ -157,6 +157,7 @@ class ClassificationProject(object): optimizer_opts=None, use_earlystopping=True, earlystopping_opts=None, + use_modelcheckpoint=True, random_seed=1234): self.name = name @@ -178,6 +179,7 @@ class ClassificationProject(object): self.step_bkg = step_bkg self.optimizer = optimizer self.use_earlystopping = use_earlystopping + self.use_modelcheckpoint = use_modelcheckpoint if optimizer_opts is None: optimizer_opts = dict() self.optimizer_opts = optimizer_opts @@ -354,6 +356,10 @@ class ClassificationProject(object): self._callbacks_list.append(self.history) if self.use_earlystopping: self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) + if self.use_modelcheckpoint: + self._callbacks_list.append(ModelCheckpoint(save_best_only=True, + verbose=True, + filepath=os.path.join(self.project_dir, "weights.h5"))) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) return self._callbacks_list @@ -574,8 +580,9 @@ class ClassificationProject(object): logger.info("Save history") self._dump_history() - logger.info("Save weights") - self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) + if not self.use_modelcheckpoint: + logger.info("Save weights") + self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) self.total_epochs += epochs self._write_info("epochs", self.total_epochs) -- GitLab