From 5c7cd191bcd20bc4bae7c229e67d4650ba8fdf70 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Mon, 23 Jul 2018 18:42:28 +0200
Subject: [PATCH] model checkpoint options

Add a `modelcheckpoint_opts` parameter so callers can configure the Keras
ModelCheckpoint callback (previously hard-coded). Also: open the options
pickle in binary mode, create the project directory before it is first used
for the default checkpoint path, and on load copy the most recent
`weights*.h5` checkpoint to `weights.h5` before reloading weights.
---
 toolkit.py | 43 +++++++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 3abfb06..0db064a 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -15,6 +15,8 @@ import pickle
 import importlib
 import csv
 import math
+import glob
+import shutil
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -134,6 +136,8 @@ class ClassificationProject(object):
 
     :param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement
 
+    :param modelcheckpoint_opts: options for the Keras ModelCheckpoint callback
+
     :param balance_dataset: if True, balance the dataset instead of
                             applying class weights. Only a fraction of the overrepresented
                             class will be used in each epoch, but different subsets of the
@@ -159,7 +163,7 @@ class ClassificationProject(object):
         else:
             # otherwise initialise new project
             self._init_from_args(name, *args, **kwargs)
-            with open(os.path.join(self.project_dir, "options.pickle"), "w") as of:
+            with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of:
                 pickle.dump(dict(args=args, kwargs=kwargs), of)
 
 
@@ -169,7 +173,7 @@ class ClassificationProject(object):
             with open(os.path.join(dirname, "options.json")) as f:
                 options = byteify(json.load(f))
         else:
-            with open(os.path.join(dirname, "options.pickle")) as f:
+            with open(os.path.join(dirname, "options.pickle"), "rb") as f:
                 options = pickle.load(f)
         options["kwargs"]["project_dir"] = dirname
         self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"])
@@ -177,6 +181,7 @@ class ClassificationProject(object):
 
     def _init_from_args(self, name,
                         signal_trees, bkg_trees, branches, weight_expr,
+                        project_dir=None,
                         data_dir=None,
                         identifiers=None,
                         selection=None,
@@ -187,7 +192,6 @@ class ClassificationProject(object):
                         validation_split=0.33,
                         activation_function='relu',
                         activation_function_output='sigmoid',
-                        project_dir=None,
                         scaler_type="RobustScaler",
                         step_signal=2,
                         step_bkg=2,
@@ -196,6 +200,7 @@ class ClassificationProject(object):
                         use_earlystopping=True,
                         earlystopping_opts=None,
                         use_modelcheckpoint=True,
+                        modelcheckpoint_opts=None,
                         random_seed=1234,
                         balance_dataset=False):
 
@@ -205,6 +210,14 @@ class ClassificationProject(object):
         self.branches = branches
         self.weight_expr = weight_expr
         self.selection = selection
+
+        self.project_dir = project_dir
+        if self.project_dir is None:
+            self.project_dir = name
+
+        if not os.path.exists(self.project_dir):
+            os.mkdir(self.project_dir)
+
         self.data_dir = data_dir
         if identifiers is None:
             identifiers = []
@@ -228,16 +241,16 @@ class ClassificationProject(object):
         if earlystopping_opts is None:
             earlystopping_opts = dict()
         self.earlystopping_opts = earlystopping_opts
+        if modelcheckpoint_opts is None:
+            modelcheckpoint_opts = dict(
+                save_best_only=True,
+                verbose=True,
+                filepath=os.path.join(self.project_dir, "weights.h5")
+            )
+        self.modelcheckpoint_opts = modelcheckpoint_opts
         self.random_seed = random_seed
         self.balance_dataset = balance_dataset
 
-        self.project_dir = project_dir
-        if self.project_dir is None:
-            self.project_dir = name
-
-        if not os.path.exists(self.project_dir):
-            os.mkdir(self.project_dir)
-
         self.s_train = None
         self.b_train = None
         self.s_test = None
@@ -411,9 +424,7 @@ class ClassificationProject(object):
         if self.use_earlystopping:
             self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
         if self.use_modelcheckpoint:
-            self._callbacks_list.append(ModelCheckpoint(save_best_only=True,
-                                                        verbose=True,
-                                                        filepath=os.path.join(self.project_dir, "weights.h5")))
+            self._callbacks_list.append(ModelCheckpoint(**self.modelcheckpoint_opts))
         self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
         return self._callbacks_list
 
@@ -728,8 +739,12 @@ class ClassificationProject(object):
             logger.info("Save weights")
             self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
         else:
+            weight_file = sorted(glob.glob(os.path.join(self.project_dir, "weights*.h5")), key=lambda f:os.path.getmtime(f))[-1]
+            if not os.path.basename(weight_file) == "weights.h5":
+                logger.info("Copying latest weight file {} to weights.h5".format(weight_file))
+                shutil.copy(weight_file, os.path.join(self.project_dir, "weights.h5"))
+            logger.info("Reloading weights file since we are using model checkpoint!")
             self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
-            logger.info("Reloading weights, since we are using model checkpoint!")
 
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
-- 
GitLab