Skip to content
Snippets Groups Projects
Commit d665c504 authored by Nikolai's avatar Nikolai
Browse files

ensure that the model weights end up in project_dir

parent 5c7cd191
No related branches found
No related tags found
No related merge requests found
......@@ -134,9 +134,12 @@ class ClassificationProject(object):
:param earlystopping_opts: options for the keras EarlyStopping callback
:param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement
:param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement (except if the options are set otherwise).
:param modelcheckpoint_opts: options for the Keras ModelCheckpoint callback
:param modelcheckpoint_opts: options for the Keras ModelCheckpoint
callback. After training, the newest saved weight will be used. If
you change the format of the saved model weights it has to be of
the form "weights*.h5"
:param balance_dataset: if True, balance the dataset instead of
applying class weights. Only a fraction of the overrepresented
......@@ -245,7 +248,7 @@ class ClassificationProject(object):
modelcheckpoint_opts = dict(
save_best_only=True,
verbose=True,
filepath=os.path.join(self.project_dir, "weights.h5")
filepath="weights.h5"
)
self.modelcheckpoint_opts = modelcheckpoint_opts
self.random_seed = random_seed
......@@ -424,7 +427,11 @@ class ClassificationProject(object):
if self.use_earlystopping:
self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
if self.use_modelcheckpoint:
self._callbacks_list.append(ModelCheckpoint(**self.modelcheckpoint_opts))
mc = ModelCheckpoint(**self.modelcheckpoint_opts)
self._callbacks_list.append(mc)
if not os.path.dirname(mc.filepath) == self.project_dir:
mc.filepath = os.path.join(self.project_dir, mc.filepath)
logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath))
self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
return self._callbacks_list
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment