Skip to content
Snippets Groups Projects
Commit 4bd7de4a authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

use ModelCheckpoint callback

parent 487774d8
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,7 @@ from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.models import model_from_json
from keras.callbacks import History, EarlyStopping, CSVLogger
from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint
from keras.optimizers import SGD
import keras.optimizers
......@@ -157,6 +157,7 @@ class ClassificationProject(object):
optimizer_opts=None,
use_earlystopping=True,
earlystopping_opts=None,
use_modelcheckpoint=True,
random_seed=1234):
self.name = name
......@@ -178,6 +179,7 @@ class ClassificationProject(object):
self.step_bkg = step_bkg
self.optimizer = optimizer
self.use_earlystopping = use_earlystopping
self.use_modelcheckpoint = use_modelcheckpoint
if optimizer_opts is None:
optimizer_opts = dict()
self.optimizer_opts = optimizer_opts
......@@ -354,6 +356,10 @@ class ClassificationProject(object):
self._callbacks_list.append(self.history)
if self.use_earlystopping:
self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
if self.use_modelcheckpoint:
self._callbacks_list.append(ModelCheckpoint(save_best_only=True,
verbose=True,
filepath=os.path.join(self.project_dir, "weights.h5")))
self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
return self._callbacks_list
......@@ -574,8 +580,9 @@ class ClassificationProject(object):
logger.info("Save history")
self._dump_history()
logger.info("Save weights")
self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
if not self.use_modelcheckpoint:
logger.info("Save weights")
self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
self.total_epochs += epochs
self._write_info("epochs", self.total_epochs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment