From 5c7cd191bcd20bc4bae7c229e67d4650ba8fdf70 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Mon, 23 Jul 2018 18:42:28 +0200 Subject: [PATCH] model checkpoint options option to set modelcheckpoint options --- toolkit.py | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/toolkit.py b/toolkit.py index 3abfb06..0db064a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -15,6 +15,8 @@ import pickle import importlib import csv import math +import glob +import shutil import logging logger = logging.getLogger("KerasROOTClassification") @@ -134,6 +136,8 @@ class ClassificationProject(object): :param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement + :param modelcheckpoint_opts: options for the Keras ModelCheckpoint callback + :param balance_dataset: if True, balance the dataset instead of applying class weights. Only a fraction of the overrepresented class will be used in each epoch, but different subsets of the @@ -159,7 +163,7 @@ class ClassificationProject(object): else: # otherwise initialise new project self._init_from_args(name, *args, **kwargs) - with open(os.path.join(self.project_dir, "options.pickle"), "w") as of: + with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of: pickle.dump(dict(args=args, kwargs=kwargs), of) @@ -169,7 +173,7 @@ class ClassificationProject(object): with open(os.path.join(dirname, "options.json")) as f: options = byteify(json.load(f)) else: - with open(os.path.join(dirname, "options.pickle")) as f: + with open(os.path.join(dirname, "options.pickle"), "rb") as f: options = pickle.load(f) options["kwargs"]["project_dir"] = dirname self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"]) @@ -177,6 +181,7 @@ class ClassificationProject(object): def _init_from_args(self, name, signal_trees, bkg_trees, branches, weight_expr, + project_dir=None, data_dir=None, identifiers=None, selection=None, @@ -187,7 +192,6 @@ class ClassificationProject(object): validation_split=0.33, activation_function='relu', activation_function_output='sigmoid', - project_dir=None, scaler_type="RobustScaler", step_signal=2, step_bkg=2, @@ -196,6 +200,7 @@ class ClassificationProject(object): use_earlystopping=True, earlystopping_opts=None, use_modelcheckpoint=True, + modelcheckpoint_opts=None, random_seed=1234, balance_dataset=False): @@ -205,6 +210,14 @@ class ClassificationProject(object): self.branches = branches self.weight_expr = weight_expr self.selection = selection + + self.project_dir = project_dir + if self.project_dir is None: + self.project_dir = name + + if not os.path.exists(self.project_dir): + os.mkdir(self.project_dir) + self.data_dir = data_dir if identifiers is None: identifiers = [] @@ -228,16 +241,16 @@ class ClassificationProject(object): if earlystopping_opts is None: earlystopping_opts = dict() self.earlystopping_opts = earlystopping_opts + if modelcheckpoint_opts is None: + modelcheckpoint_opts = dict( + save_best_only=True, + verbose=True, + filepath=os.path.join(self.project_dir, "weights.h5") + ) + self.modelcheckpoint_opts = modelcheckpoint_opts self.random_seed = random_seed self.balance_dataset = balance_dataset - self.project_dir = project_dir - if self.project_dir is None: - self.project_dir = name - - if not os.path.exists(self.project_dir): - os.mkdir(self.project_dir) - self.s_train = None self.b_train = None self.s_test = None @@ -411,9 +424,7 @@ class ClassificationProject(object): if self.use_earlystopping: self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) if self.use_modelcheckpoint: - self._callbacks_list.append(ModelCheckpoint(save_best_only=True, - verbose=True, - filepath=os.path.join(self.project_dir, "weights.h5"))) + self._callbacks_list.append(ModelCheckpoint(**self.modelcheckpoint_opts)) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) return self._callbacks_list @@ -728,8 +739,12 @@ class ClassificationProject(object): logger.info("Save weights") self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) else: + weight_file = sorted(glob.glob(os.path.join(self.project_dir, "weights*.h5")), key=lambda f:os.path.getmtime(f))[-1] + if not os.path.basename(weight_file) == "weights.h5": + logger.info("Copying latest weight file {} to weights.h5".format(weight_file)) + shutil.copy(weight_file, os.path.join(self.project_dir, "weights.h5")) + logger.info("Reloading weights file since we are using model checkpoint!") self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) - logger.info("Reloading weights, since we are using model checkpoint!") self.total_epochs += epochs self._write_info("epochs", self.total_epochs) -- GitLab