Skip to content
Snippets Groups Projects
Commit 423273cd authored by Thomas Weber's avatar Thomas Weber
Browse files

Add EarlyStopping callback

Stop training when a monitored quantity has stopped improving.
parent b4048320
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ from sklearn.metrics import roc_curve, auc ...@@ -20,7 +20,7 @@ from sklearn.metrics import roc_curve, auc
from keras.models import Sequential from keras.models import Sequential
from keras.layers import Dense from keras.layers import Dense
from keras.models import model_from_json from keras.models import model_from_json
from keras.callbacks import History from keras.callbacks import History, EarlyStopping
from keras.optimizers import SGD from keras.optimizers import SGD
import keras.optimizers import keras.optimizers
...@@ -68,7 +68,9 @@ class KerasROOTClassification(object): ...@@ -68,7 +68,9 @@ class KerasROOTClassification(object):
step_signal=2, step_signal=2,
step_bkg=2, step_bkg=2,
optimizer="SGD", optimizer="SGD",
optimizer_opts=None): optimizer_opts=None,
earlystopping_opts=None):
self.name = name self.name = name
self.signal_trees = signal_trees self.signal_trees = signal_trees
self.bkg_trees = bkg_trees self.bkg_trees = bkg_trees
...@@ -89,6 +91,9 @@ class KerasROOTClassification(object): ...@@ -89,6 +91,9 @@ class KerasROOTClassification(object):
if optimizer_opts is None: if optimizer_opts is None:
optimizer_opts = dict() optimizer_opts = dict()
self.optimizer_opts = optimizer_opts self.optimizer_opts = optimizer_opts
if earlystopping_opts is None:
earlystopping_opts = dict()
self.earlystopping_opts = earlystopping_opts
self.project_dir = os.path.join(self.out_dir, name) self.project_dir = os.path.join(self.out_dir, name)
...@@ -121,6 +126,7 @@ class KerasROOTClassification(object): ...@@ -121,6 +126,7 @@ class KerasROOTClassification(object):
self._sig_weights = None self._sig_weights = None
self._model = None self._model = None
self._history = None self._history = None
self._callbacks_list = []
# track the number of epochs this model has been trained # track the number of epochs this model has been trained
self.total_epochs = 0 self.total_epochs = 0
...@@ -224,6 +230,15 @@ class KerasROOTClassification(object): ...@@ -224,6 +230,15 @@ class KerasROOTClassification(object):
logger.info("Data loaded") logger.info("Data loaded")
@property
def callbacks_list(self):
    """List of Keras callbacks passed to ``model.fit``.

    Built lazily on first access and cached: the training history
    callback plus an ``EarlyStopping`` configured from
    ``self.earlystopping_opts``.
    """
    # populate the cached list only once; later accesses reuse it
    if not self._callbacks_list:
        self._callbacks_list.extend([
            self.history,
            EarlyStopping(**self.earlystopping_opts),
        ])
    return self._callbacks_list
@property @property
def scaler(self): def scaler(self):
# create the scaler (and fit to training data) if not existent # create the scaler (and fit to training data) if not existent
...@@ -389,6 +404,7 @@ class KerasROOTClassification(object): ...@@ -389,6 +404,7 @@ class KerasROOTClassification(object):
try: try:
self.history = History() self.history = History()
self.shuffle_training_data() self.shuffle_training_data()
self.model.fit(self.x_train, self.model.fit(self.x_train,
# the reshape might be unnescessary here # the reshape might be unnescessary here
self.y_train.reshape(-1, 1), self.y_train.reshape(-1, 1),
...@@ -398,7 +414,7 @@ class KerasROOTClassification(object): ...@@ -398,7 +414,7 @@ class KerasROOTClassification(object):
sample_weight=self.w_train, sample_weight=self.w_train,
shuffle=True, shuffle=True,
batch_size=self.batch_size, batch_size=self.batch_size,
callbacks=[self.history]) callbacks=self.callbacks_list)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest") logger.info("Interrupt training - continue with rest")
...@@ -583,8 +599,10 @@ if __name__ == "__main__": ...@@ -583,8 +599,10 @@ if __name__ == "__main__":
bkg_trees = [(filename, "ttbar_NoSys"), bkg_trees = [(filename, "ttbar_NoSys"),
(filename, "wjets_Sherpa221_NoSys") (filename, "wjets_Sherpa221_NoSys")
], ],
optimizer="SGD", optimizer="Adam",
optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
earlystopping_opts=dict(monitor='val_loss',
min_delta=0, patience=2, verbose=0, mode='auto'),
# optimizer="Adam", # optimizer="Adam",
selection="lep1Pt<5000", # cut out a few very weird outliers selection="lep1Pt<5000", # cut out a few very weird outliers
branches = ["met", "mt"], branches = ["met", "mt"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment