From d665c504a592762c5073e4e0c882238915595774 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 24 Jul 2018 09:46:16 +0200
Subject: [PATCH] ensure that the model weights end up in project_dir

---
 toolkit.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 0db064a..21dda44 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -134,9 +134,12 @@ class ClassificationProject(object):
 
     :param earlystopping_opts: options for the keras EarlyStopping callback
 
-    :param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement
+    :param use_modelcheckpoint: save model weights after each epoch, skipping epochs where the validation loss did not improve (unless the options are set otherwise).
 
-    :param modelcheckpoint_opts: options for the Keras ModelCheckpoint callback
+    :param modelcheckpoint_opts: options for the Keras ModelCheckpoint
+                                 callback. After training, the newest saved weights file will be used.
+                                 If you change the filename format of the saved model weights, it has
+                                 to match the pattern "weights*.h5"
 
     :param balance_dataset: if True, balance the dataset instead of
                             applying class weights. Only a fraction of the overrepresented
@@ -245,7 +248,7 @@ class ClassificationProject(object):
             modelcheckpoint_opts = dict(
                 save_best_only=True,
                 verbose=True,
-                filepath=os.path.join(self.project_dir, "weights.h5")
+                filepath="weights.h5"
             )
         self.modelcheckpoint_opts = modelcheckpoint_opts
         self.random_seed = random_seed
@@ -424,7 +427,11 @@ class ClassificationProject(object):
         if self.use_earlystopping:
             self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
         if self.use_modelcheckpoint:
-            self._callbacks_list.append(ModelCheckpoint(**self.modelcheckpoint_opts))
+            mc = ModelCheckpoint(**self.modelcheckpoint_opts)
+            self._callbacks_list.append(mc)
+            if not os.path.dirname(mc.filepath) == self.project_dir:
+                mc.filepath = os.path.join(self.project_dir, mc.filepath)
+                logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath))
         self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
         return self._callbacks_list
 
-- 
GitLab