diff --git a/toolkit.py b/toolkit.py
index 8f10b92f0a0659cd4849ca612ebcf3936271a7f7..25231056fb769f6253fe6c825eac65e85451aaf9 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -20,7 +20,7 @@ from sklearn.metrics import roc_curve, auc
 from keras.models import Sequential
 from keras.layers import Dense
 from keras.models import model_from_json
-from keras.callbacks import History
+from keras.callbacks import History, EarlyStopping
 from keras.optimizers import SGD
 import keras.optimizers
 
@@ -68,7 +68,9 @@ class KerasROOTClassification(object):
                         step_signal=2,
                         step_bkg=2,
                         optimizer="SGD",
-                        optimizer_opts=None):
+                        optimizer_opts=None,
+                        earlystopping_opts=None):
+
         self.name = name
         self.signal_trees = signal_trees
         self.bkg_trees = bkg_trees
@@ -89,6 +91,9 @@ class KerasROOTClassification(object):
         if optimizer_opts is None:
             optimizer_opts = dict()
         self.optimizer_opts = optimizer_opts
+        if earlystopping_opts is None:
+            earlystopping_opts = dict()
+        self.earlystopping_opts = earlystopping_opts
 
         self.project_dir = os.path.join(self.out_dir, name)
 
@@ -121,6 +126,7 @@ class KerasROOTClassification(object):
         self._sig_weights = None
         self._model = None
         self._history = None
+        self._callbacks_list = []
 
         # track the number of epochs this model has been trained
         self.total_epochs = 0
@@ -224,6 +230,15 @@ class KerasROOTClassification(object):
         logger.info("Data loaded")
 
 
+    @property
+    def callbacks_list(self):
+        if not self._callbacks_list:
+            self._callbacks_list.append(self.history)
+            self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
+
+        return self._callbacks_list
+
+
     @property
     def scaler(self):
         # create the scaler (and fit to training data) if not existent
@@ -389,6 +404,7 @@ class KerasROOTClassification(object):
         try:
             self.history = History()
             self.shuffle_training_data()
+            
             self.model.fit(self.x_train,
                            # the reshape might be unnescessary here
                            self.y_train.reshape(-1, 1),
@@ -398,7 +414,7 @@ class KerasROOTClassification(object):
                            sample_weight=self.w_train,
                            shuffle=True,
                            batch_size=self.batch_size,
-                           callbacks=[self.history])
+                           callbacks=self.callbacks_list)
         except KeyboardInterrupt:
             logger.info("Interrupt training - continue with rest")
 
@@ -583,8 +599,10 @@ if __name__ == "__main__":
                                 bkg_trees = [(filename, "ttbar_NoSys"),
                                              (filename, "wjets_Sherpa221_NoSys")
                                 ],
-                                optimizer="SGD",
-                                optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
+                                optimizer="Adam",
+                                #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
+                                earlystopping_opts=dict(monitor='val_loss',
+                                    min_delta=0, patience=2, verbose=0, mode='auto'),
                                 # optimizer="Adam",
                                 selection="lep1Pt<5000", # cut out a few very weird outliers
                                 branches = ["met", "mt"],