diff --git a/toolkit.py b/toolkit.py
index 3db07b0d43521e67499ccd62e0df2162ddbf7322..789467549887f87d25c988713db4fc3429e294a5 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -2,6 +2,7 @@
 
 import os
 import json
+import pickle
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -18,6 +19,7 @@ from sklearn.metrics import roc_curve, auc
 from keras.models import Sequential
 from keras.layers import Dense
 from keras.models import model_from_json
+from keras.callbacks import History
 import matplotlib.pyplot as plt
 
 import matplotlib.pyplot as plt
@@ -256,6 +258,33 @@ class KerasROOTClassification(object):
         return self._scaler
 
 
+    @property
+    def history(self):
+        params_file = os.path.join(self.project_dir, "history_params.json")
+        history_file = os.path.join(self.project_dir, "history_history.json")
+        if self._history is None:
+            self._history = History()
+            with open(params_file) as f:
+                self._history.params = json.load(f)
+            with open(history_file) as f:
+                self._history.history = json.load(f)
+        return self._history
+
+
+    @history.setter
+    def history(self, value):
+        self._history = value
+
+
+    def _dump_history(self):
+        params_file = os.path.join(self.project_dir, "history_params.json")
+        history_file = os.path.join(self.project_dir, "history_history.json")
+        with open(params_file, "wb") as of:
+            json.dump(self.history.params, of)
+        with open(history_file, "wb") as of:
+            json.dump(self.history.history, of)
+
+
     def _transform_data(self):
         if not self.data_transformed:
             # todo: what to do about the outliers? Where do they come from?
@@ -324,14 +353,28 @@ class KerasROOTClassification(object):
         return self._class_weight
 
 
-    def train(self, epochs=10):
-
+    def load(self):
+        "Load all data needed for plotting and training"
         if not self.data_loaded:
             self._load_data()
 
         if not self.data_transformed:
             self._transform_data()
 
+
+    def shuffle_training_data(self):
+        rn_state = np.random.get_state()
+        np.random.shuffle(self.x_train)
+        np.random.set_state(rn_state)
+        np.random.shuffle(self.y_train)
+        np.random.set_state(rn_state)
+        np.random.shuffle(self.w_train)
+
+
+    def train(self, epochs=10):
+
+        self.load()
+
         for branch_index, branch in enumerate(self.branches):
             self.plot_input(branch_index)
 
@@ -345,14 +388,26 @@ class KerasROOTClassification(object):
         self.total_epochs = self._read_info("epochs", 0)
 
         logger.info("Train model")
-        self._history = self.model.fit(self.x_train,
-                                       # the reshape might be unnescessary here
-                                       self.y_train.reshape(-1, 1),
-                                       epochs=epochs,
-                                       validation_split = self.validation_split,
-                                       class_weight=self.class_weight,
-                                       shuffle=True,
-                                       batch_size=self.batch_size)
+        try:
+            self.history = History()
+            self.shuffle_training_data()
+            self.model.fit(self.x_train,
+                           # the reshape might be unnescessary here
+                           self.y_train.reshape(-1, 1),
+                           epochs=epochs,
+                           validation_split = self.validation_split,
+                           class_weight=self.class_weight,
+                           sample_weight=self.w_train,
+                           shuffle=True,
+                           batch_size=self.batch_size,
+                           callbacks=[self.history])
+        except KeyboardInterrupt:
+            logger.info("Interrupt training - continue with rest")
+
+        print(self.history)
+
+        logger.info("Save history")
+        self._dump_history()
 
         logger.info("Save weights")
         self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
@@ -467,8 +522,8 @@ class KerasROOTClassification(object):
     def plot_loss(self):
 
         logger.info("Plot losses")
-        plt.plot(self._history.history['loss'])
-        plt.plot(self._history.history['val_loss'])
+        plt.plot(self.history.history['loss'])
+        plt.plot(self.history.history['val_loss'])
         plt.ylabel('loss')
         plt.xlabel('epoch')
         plt.legend(['train','test'], loc='upper left')
@@ -479,8 +534,8 @@ class KerasROOTClassification(object):
     def plot_accuracy(self):
         
         logger.info("Plot accuracy")
-        plt.plot(self._history.history['acc'])
-        plt.plot(self._history.history['val_acc'])
+        plt.plot(self.history.history['acc'])
+        plt.plot(self.history.history['val_acc'])
         plt.title('model accuracy')
         plt.ylabel('accuracy')
         plt.xlabel('epoch')
@@ -507,7 +562,8 @@ if __name__ == "__main__":
                                 identifiers = ["DatasetNumber", "EventNumber"],
                                 step_bkg = 100)
 
-#    c.train(epochs=10)
+    #c.load()
+    c.train(epochs=10)
     c.plot_ROC()
     c.plot_loss()
     c.plot_accuracy()