Skip to content
Snippets Groups Projects
Commit 5c7cd191 authored by Nikolai's avatar Nikolai
Browse files

model checkpoint options

Add an option to configure the Keras ModelCheckpoint callback options.
parent 6300dafe
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,8 @@ import pickle
import importlib
import csv
import math
import glob
import shutil
import logging
logger = logging.getLogger("KerasROOTClassification")
......@@ -134,6 +136,8 @@ class ClassificationProject(object):
:param use_modelcheckpoint: save model weights after each epoch; by default only when the validation loss improves
:param modelcheckpoint_opts: options for the Keras ModelCheckpoint callback
:param balance_dataset: if True, balance the dataset instead of
applying class weights. Only a fraction of the overrepresented
class will be used in each epoch, but different subsets of the
......@@ -159,7 +163,7 @@ class ClassificationProject(object):
else:
# otherwise initialise new project
self._init_from_args(name, *args, **kwargs)
with open(os.path.join(self.project_dir, "options.pickle"), "w") as of:
with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of:
pickle.dump(dict(args=args, kwargs=kwargs), of)
......@@ -169,7 +173,7 @@ class ClassificationProject(object):
with open(os.path.join(dirname, "options.json")) as f:
options = byteify(json.load(f))
else:
with open(os.path.join(dirname, "options.pickle")) as f:
with open(os.path.join(dirname, "options.pickle"), "rb") as f:
options = pickle.load(f)
options["kwargs"]["project_dir"] = dirname
self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"])
......@@ -177,6 +181,7 @@ class ClassificationProject(object):
def _init_from_args(self, name,
signal_trees, bkg_trees, branches, weight_expr,
project_dir=None,
data_dir=None,
identifiers=None,
selection=None,
......@@ -187,7 +192,6 @@ class ClassificationProject(object):
validation_split=0.33,
activation_function='relu',
activation_function_output='sigmoid',
project_dir=None,
scaler_type="RobustScaler",
step_signal=2,
step_bkg=2,
......@@ -196,6 +200,7 @@ class ClassificationProject(object):
use_earlystopping=True,
earlystopping_opts=None,
use_modelcheckpoint=True,
modelcheckpoint_opts=None,
random_seed=1234,
balance_dataset=False):
......@@ -205,6 +210,14 @@ class ClassificationProject(object):
self.branches = branches
self.weight_expr = weight_expr
self.selection = selection
self.project_dir = project_dir
if self.project_dir is None:
self.project_dir = name
if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir)
self.data_dir = data_dir
if identifiers is None:
identifiers = []
......@@ -228,16 +241,16 @@ class ClassificationProject(object):
if earlystopping_opts is None:
earlystopping_opts = dict()
self.earlystopping_opts = earlystopping_opts
if modelcheckpoint_opts is None:
modelcheckpoint_opts = dict(
save_best_only=True,
verbose=True,
filepath=os.path.join(self.project_dir, "weights.h5")
)
self.modelcheckpoint_opts = modelcheckpoint_opts
self.random_seed = random_seed
self.balance_dataset = balance_dataset
self.project_dir = project_dir
if self.project_dir is None:
self.project_dir = name
if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir)
self.s_train = None
self.b_train = None
self.s_test = None
......@@ -411,9 +424,7 @@ class ClassificationProject(object):
if self.use_earlystopping:
self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
if self.use_modelcheckpoint:
self._callbacks_list.append(ModelCheckpoint(save_best_only=True,
verbose=True,
filepath=os.path.join(self.project_dir, "weights.h5")))
self._callbacks_list.append(ModelCheckpoint(**self.modelcheckpoint_opts))
self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
return self._callbacks_list
......@@ -728,8 +739,12 @@ class ClassificationProject(object):
logger.info("Save weights")
self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
else:
weight_file = sorted(glob.glob(os.path.join(self.project_dir, "weights*.h5")), key=lambda f:os.path.getmtime(f))[-1]
if not os.path.basename(weight_file) == "weights.h5":
logger.info("Copying latest weight file {} to weights.h5".format(weight_file))
shutil.copy(weight_file, os.path.join(self.project_dir, "weights.h5"))
logger.info("Reloading weights file since we are using model checkpoint!")
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Reloading weights, since we are using model checkpoint!")
self.total_epochs += epochs
self._write_info("epochs", self.total_epochs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment