From d665c504a592762c5073e4e0c882238915595774 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Tue, 24 Jul 2018 09:46:16 +0200 Subject: [PATCH] ensure that the model weights end up in project_dir --- toolkit.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/toolkit.py b/toolkit.py index 0db064a..21dda44 100755 --- a/toolkit.py +++ b/toolkit.py @@ -134,9 +134,12 @@ class ClassificationProject(object): :param earlystopping_opts: options for the keras EarlyStopping callback - :param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement + :param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement (except if the options are set otherwise). - :param modelcheckpoint_opts: options for the Keras ModelCheckpoint callback + :param modelcheckpoint_opts: options for the Keras ModelCheckpoint + callback. After training, the newest saved weight will be used. If + you change the format of the saved model weights it has to be of + the form "weights*.h5" :param balance_dataset: if True, balance the dataset instead of applying class weights. Only a fraction of the overrepresented @@ -245,7 +248,7 @@ class ClassificationProject(object): modelcheckpoint_opts = dict( save_best_only=True, verbose=True, - filepath=os.path.join(self.project_dir, "weights.h5") + filepath="weights.h5" ) self.modelcheckpoint_opts = modelcheckpoint_opts self.random_seed = random_seed @@ -424,7 +427,11 @@ class ClassificationProject(object): if self.use_earlystopping: self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) if self.use_modelcheckpoint: - self._callbacks_list.append(ModelCheckpoint(**self.modelcheckpoint_opts)) + mc = ModelCheckpoint(**self.modelcheckpoint_opts) + self._callbacks_list.append(mc) + if not os.path.dirname(mc.filepath) == self.project_dir: + mc.filepath = os.path.join(self.project_dir, mc.filepath) + logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath)) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) return self._callbacks_list -- GitLab