diff --git a/toolkit.py b/toolkit.py index 8f10b92f0a0659cd4849ca612ebcf3936271a7f7..25231056fb769f6253fe6c825eac65e85451aaf9 100755 --- a/toolkit.py +++ b/toolkit.py @@ -20,7 +20,7 @@ from sklearn.metrics import roc_curve, auc from keras.models import Sequential from keras.layers import Dense from keras.models import model_from_json -from keras.callbacks import History +from keras.callbacks import History, EarlyStopping from keras.optimizers import SGD import keras.optimizers @@ -68,7 +68,9 @@ class KerasROOTClassification(object): step_signal=2, step_bkg=2, optimizer="SGD", - optimizer_opts=None): + optimizer_opts=None, + earlystopping_opts=None): + self.name = name self.signal_trees = signal_trees self.bkg_trees = bkg_trees @@ -89,6 +91,9 @@ class KerasROOTClassification(object): if optimizer_opts is None: optimizer_opts = dict() self.optimizer_opts = optimizer_opts + if earlystopping_opts is None: + earlystopping_opts = dict() + self.earlystopping_opts = earlystopping_opts self.project_dir = os.path.join(self.out_dir, name) @@ -121,6 +126,7 @@ class KerasROOTClassification(object): self._sig_weights = None self._model = None self._history = None + self._callbacks_list = [] # track the number of epochs this model has been trained self.total_epochs = 0 @@ -224,6 +230,15 @@ class KerasROOTClassification(object): logger.info("Data loaded") + @property + def callbacks_list(self): + if not self._callbacks_list: + self._callbacks_list.append(self.history) + self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) + + return self._callbacks_list + + @property def scaler(self): # create the scaler (and fit to training data) if not existent @@ -389,6 +404,7 @@ class KerasROOTClassification(object): try: self.history = History() self.shuffle_training_data() + self.model.fit(self.x_train, # the reshape might be unnescessary here self.y_train.reshape(-1, 1), @@ -398,7 +414,7 @@ class KerasROOTClassification(object): sample_weight=self.w_train, shuffle=True, batch_size=self.batch_size, - callbacks=[self.history]) + callbacks=self.callbacks_list) except KeyboardInterrupt: logger.info("Interrupt training - continue with rest") @@ -583,8 +599,10 @@ if __name__ == "__main__": bkg_trees = [(filename, "ttbar_NoSys"), (filename, "wjets_Sherpa221_NoSys") ], - optimizer="SGD", - optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), + optimizer="Adam", + #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), + earlystopping_opts=dict(monitor='val_loss', + min_delta=0, patience=2, verbose=0, mode='auto'), # optimizer="Adam", selection="lep1Pt<5000", # cut out a few very weird outliers branches = ["met", "mt"],