Skip to content
Snippets Groups Projects
Commit 9b68e0b4 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Merge remote-tracking branch 'origin/master'

parents fef12369 d665c504
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,8 @@ import pickle ...@@ -15,6 +15,8 @@ import pickle
import importlib import importlib
import csv import csv
import math import math
import glob
import shutil
import logging import logging
logger = logging.getLogger("KerasROOTClassification") logger = logging.getLogger("KerasROOTClassification")
...@@ -132,7 +134,12 @@ class ClassificationProject(object): ...@@ -132,7 +134,12 @@ class ClassificationProject(object):
:param earlystopping_opts: options for the keras EarlyStopping callback :param earlystopping_opts: options for the keras EarlyStopping callback
:param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement :param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement (except if the options are set otherwise).
:param modelcheckpoint_opts: options for the Keras ModelCheckpoint
callback. After training, the newest saved weight will be used. If
you change the format of the saved model weights it has to be of
the form "weights*.h5"
:param balance_dataset: if True, balance the dataset instead of :param balance_dataset: if True, balance the dataset instead of
applying class weights. Only a fraction of the overrepresented applying class weights. Only a fraction of the overrepresented
...@@ -159,7 +166,7 @@ class ClassificationProject(object): ...@@ -159,7 +166,7 @@ class ClassificationProject(object):
else: else:
# otherwise initialise new project # otherwise initialise new project
self._init_from_args(name, *args, **kwargs) self._init_from_args(name, *args, **kwargs)
with open(os.path.join(self.project_dir, "options.pickle"), "w") as of: with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of:
pickle.dump(dict(args=args, kwargs=kwargs), of) pickle.dump(dict(args=args, kwargs=kwargs), of)
...@@ -169,7 +176,7 @@ class ClassificationProject(object): ...@@ -169,7 +176,7 @@ class ClassificationProject(object):
with open(os.path.join(dirname, "options.json")) as f: with open(os.path.join(dirname, "options.json")) as f:
options = byteify(json.load(f)) options = byteify(json.load(f))
else: else:
with open(os.path.join(dirname, "options.pickle")) as f: with open(os.path.join(dirname, "options.pickle"), "rb") as f:
options = pickle.load(f) options = pickle.load(f)
options["kwargs"]["project_dir"] = dirname options["kwargs"]["project_dir"] = dirname
self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"]) self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"])
...@@ -177,6 +184,7 @@ class ClassificationProject(object): ...@@ -177,6 +184,7 @@ class ClassificationProject(object):
def _init_from_args(self, name, def _init_from_args(self, name,
signal_trees, bkg_trees, branches, weight_expr, signal_trees, bkg_trees, branches, weight_expr,
project_dir=None,
data_dir=None, data_dir=None,
identifiers=None, identifiers=None,
selection=None, selection=None,
...@@ -187,7 +195,6 @@ class ClassificationProject(object): ...@@ -187,7 +195,6 @@ class ClassificationProject(object):
validation_split=0.33, validation_split=0.33,
activation_function='relu', activation_function='relu',
activation_function_output='sigmoid', activation_function_output='sigmoid',
project_dir=None,
scaler_type="RobustScaler", scaler_type="RobustScaler",
step_signal=2, step_signal=2,
step_bkg=2, step_bkg=2,
...@@ -196,6 +203,7 @@ class ClassificationProject(object): ...@@ -196,6 +203,7 @@ class ClassificationProject(object):
use_earlystopping=True, use_earlystopping=True,
earlystopping_opts=None, earlystopping_opts=None,
use_modelcheckpoint=True, use_modelcheckpoint=True,
modelcheckpoint_opts=None,
random_seed=1234, random_seed=1234,
balance_dataset=False): balance_dataset=False):
...@@ -205,6 +213,14 @@ class ClassificationProject(object): ...@@ -205,6 +213,14 @@ class ClassificationProject(object):
self.branches = branches self.branches = branches
self.weight_expr = weight_expr self.weight_expr = weight_expr
self.selection = selection self.selection = selection
self.project_dir = project_dir
if self.project_dir is None:
self.project_dir = name
if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir)
self.data_dir = data_dir self.data_dir = data_dir
if identifiers is None: if identifiers is None:
identifiers = [] identifiers = []
...@@ -228,16 +244,16 @@ class ClassificationProject(object): ...@@ -228,16 +244,16 @@ class ClassificationProject(object):
if earlystopping_opts is None: if earlystopping_opts is None:
earlystopping_opts = dict() earlystopping_opts = dict()
self.earlystopping_opts = earlystopping_opts self.earlystopping_opts = earlystopping_opts
if modelcheckpoint_opts is None:
modelcheckpoint_opts = dict(
save_best_only=True,
verbose=True,
filepath="weights.h5"
)
self.modelcheckpoint_opts = modelcheckpoint_opts
self.random_seed = random_seed self.random_seed = random_seed
self.balance_dataset = balance_dataset self.balance_dataset = balance_dataset
self.project_dir = project_dir
if self.project_dir is None:
self.project_dir = name
if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir)
self.s_train = None self.s_train = None
self.b_train = None self.b_train = None
self.s_test = None self.s_test = None
...@@ -411,9 +427,11 @@ class ClassificationProject(object): ...@@ -411,9 +427,11 @@ class ClassificationProject(object):
if self.use_earlystopping: if self.use_earlystopping:
self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
if self.use_modelcheckpoint: if self.use_modelcheckpoint:
self._callbacks_list.append(ModelCheckpoint(save_best_only=True, mc = ModelCheckpoint(**self.modelcheckpoint_opts)
verbose=True, self._callbacks_list.append(mc)
filepath=os.path.join(self.project_dir, "weights.h5"))) if not os.path.dirname(mc.filepath) == self.project_dir:
mc.filepath = os.path.join(self.project_dir, mc.filepath)
logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath))
self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
return self._callbacks_list return self._callbacks_list
...@@ -728,8 +746,12 @@ class ClassificationProject(object): ...@@ -728,8 +746,12 @@ class ClassificationProject(object):
logger.info("Save weights") logger.info("Save weights")
self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
else: else:
weight_file = sorted(glob.glob(os.path.join(self.project_dir, "weights*.h5")), key=lambda f:os.path.getmtime(f))[-1]
if not os.path.basename(weight_file) == "weights.h5":
logger.info("Copying latest weight file {} to weights.h5".format(weight_file))
shutil.copy(weight_file, os.path.join(self.project_dir, "weights.h5"))
logger.info("Reloading weights file since we are using model checkpoint!")
self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Reloading weights, since we are using model checkpoint!")
self.total_epochs += epochs self.total_epochs += epochs
self._write_info("epochs", self.total_epochs) self._write_info("epochs", self.total_epochs)
...@@ -808,7 +830,7 @@ class ClassificationProject(object): ...@@ -808,7 +830,7 @@ class ClassificationProject(object):
if "weights" in np_kwargs: if "weights" in np_kwargs:
errors = [] errors = []
for left, right in zip(bins, bins[1:]): for left, right in zip(bins, bins[1:]):
indices = np.where((x >= left) & (x <= right))[0] indices = np.where((x >= left) & (x < right))[0]
sumw2 = np.sum(np_kwargs["weights"][indices]**2) sumw2 = np.sum(np_kwargs["weights"][indices]**2)
content = np.sum(np_kwargs["weights"][indices]) content = np.sum(np_kwargs["weights"][indices])
errors.append(math.sqrt(sumw2)/content) errors.append(math.sqrt(sumw2)/content)
...@@ -1089,7 +1111,7 @@ if __name__ == "__main__": ...@@ -1089,7 +1111,7 @@ if __name__ == "__main__":
#optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
earlystopping_opts=dict(monitor='val_loss', earlystopping_opts=dict(monitor='val_loss',
min_delta=0, patience=2, verbose=0, mode='auto'), min_delta=0, patience=2, verbose=0, mode='auto'),
selection="lep1Pt<5000", # cut out a few very weird outliers selection="1",
branches = ["met", "mt"], branches = ["met", "mt"],
weight_expr = "eventWeight*genWeight", weight_expr = "eventWeight*genWeight",
identifiers = ["DatasetNumber", "EventNumber"], identifiers = ["DatasetNumber", "EventNumber"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment