diff --git a/toolkit.py b/toolkit.py index 01b17e77ba2f28a30a48dbda3c9a36771b07f5e7..dbec934f105a28f16fcb0d20b43f42133310b599 100755 --- a/toolkit.py +++ b/toolkit.py @@ -27,7 +27,7 @@ from sklearn.metrics import roc_curve, auc from keras.models import Sequential from keras.layers import Dense, Dropout from keras.models import model_from_json -from keras.callbacks import History, EarlyStopping, CSVLogger +from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint from keras.optimizers import SGD import keras.optimizers @@ -157,6 +157,7 @@ class ClassificationProject(object): optimizer_opts=None, use_earlystopping=True, earlystopping_opts=None, + use_modelcheckpoint=True, random_seed=1234): self.name = name @@ -178,6 +179,7 @@ class ClassificationProject(object): self.step_bkg = step_bkg self.optimizer = optimizer self.use_earlystopping = use_earlystopping + self.use_modelcheckpoint = use_modelcheckpoint if optimizer_opts is None: optimizer_opts = dict() self.optimizer_opts = optimizer_opts @@ -354,6 +356,10 @@ class ClassificationProject(object): self._callbacks_list.append(self.history) if self.use_earlystopping: self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) + if self.use_modelcheckpoint: + self._callbacks_list.append(ModelCheckpoint(save_best_only=True, + verbose=True, + filepath=os.path.join(self.project_dir, "weights.h5"))) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) return self._callbacks_list @@ -574,8 +580,9 @@ class ClassificationProject(object): logger.info("Save history") self._dump_history() - logger.info("Save weights") - self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) + if not self.use_modelcheckpoint: + logger.info("Save weights") + self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) self.total_epochs += epochs self._write_info("epochs", self.total_epochs)