diff --git a/toolkit.py b/toolkit.py
index 425a95b245d63f3a705b0fd6ef3269e4e26cd649..6d860776a02b3bd2ffd58fdac18005c278a80cfd 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 
 import os
+import json
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -13,6 +14,22 @@ import h5py
 from sklearn.preprocessing import StandardScaler
 from sklearn.externals import joblib
 
+from keras.models import Sequential
+from keras.layers import Dense
+from keras.models import model_from_json
+
+# configure number of cores
+# this doesn't seem to work, but at least with these settings keras only uses 4 processes
+import tensorflow as tf
+from keras import backend as K
+num_cores = 1
+config = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
+                        inter_op_parallelism_threads=num_cores,
+                        allow_soft_placement=True,
+                        device_count = {'CPU': num_cores})
+session = tf.Session(config=config)
+K.set_session(session)
+
 import ROOT
 
 class KerasROOTClassification:
@@ -23,7 +40,7 @@ class KerasROOTClassification:
 
     def __init__(self, name,
                  signal_trees, bkg_trees, branches, weight_expr, identifiers,
-                 layers=3, nodes=64, out_dir="./outputs"):
+                 layers=3, nodes=64, batch_size=128, activation_function='relu', out_dir="./outputs"):
         self.name = name
         self.signal_trees = signal_trees
         self.bkg_trees = bkg_trees
@@ -32,6 +49,8 @@ class KerasROOTClassification:
         self.identifiers = identifiers
         self.layers = layers
         self.nodes = nodes
+        self.batch_size = batch_size
+        self.activation_function = activation_function
         self.out_dir = out_dir
 
         self.project_dir = os.path.join(self.out_dir, name)
@@ -55,9 +74,16 @@ class KerasROOTClassification:
         self.b_eventlist_train = None
 
         self._scaler = None
+        self._class_weight = None
+        self._model = None
+
+        # track the number of epochs this model has been trained
+        self.total_epochs = 0
 
+        self.data_loaded = False
 
-    def load_data(self):
+
+    def _load_data(self):
 
         try:
 
@@ -108,6 +134,8 @@ class KerasROOTClassification:
             logger.info("Writing to hdf5")
             self._dump_to_hdf5()
 
+        self.data_loaded = True
+
 
     def _dump_training_list(self):
         s_eventlist = pd.DataFrame(self.s_train[self.identifiers])
@@ -149,14 +177,83 @@ class KerasROOTClassification:
         return self._scaler
 
 
-    def _transform_data(self):
-        pass
+    def _read_info(self, key, default):
+        filename = os.path.join(self.project_dir, "info.json")
+        if not os.path.exists(filename):
+            with open(filename, "w") as of:
+                json.dump({}, of)
+        with open(filename) as f:
+            info = json.load(f)
+        return info.get(key, default)
 
-    def _create_model(self):
-        pass
 
-    def train(self):
-        pass
+    def _write_info(self, key, value):
+        filename = os.path.join(self.project_dir, "info.json")
+        with open(filename) as f:
+            info = json.load(f)
+        info[key] = value
+        with open(filename, "w") as of:
+            json.dump(info, of)
+
+
+    @property
+    def model(self):
+        "Simple MLP"
+
+        if self._model is None:
+
+            self._model = Sequential()
+
+            # first hidden layer
+            self._model.add(Dense(self.nodes, input_dim=len(self.branches), activation=self.activation_function))
+            # the other hidden layers
+            for layer_number in range(self.layers-1):
+                self._model.add(Dense(self.nodes, activation=self.activation_function))
+            # last layer is one neuron (binary classification)
+            self._model.add(Dense(1, activation='sigmoid'))
+
+            self._model.compile(optimizer='SGD',
+                  loss='binary_crossentropy',
+                  metrics=['accuracy'])
+
+            # dump to json for documentation
+            with open(os.path.join(self.project_dir, "model.json"), "w") as of:
+                of.write(self._model.to_json())
+
+        return self._model
+
+    @property
+    def class_weight(self):
+        if self._class_weight is None:
+            sumw_bkg = np.sum(self.w_train[self.y_train == 0])
+            sumw_sig = np.sum(self.w_train[self.y_train == 1])
+            self._class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
+        return self._class_weight
+
+    def train(self, epochs=10):
+
+        if not self.data_loaded:
+            self._load_data()
+
+        try:
+            self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
+            logger.info("Weights found and loaded")
+            logger.info("Continue training")
+        except IOError:
+            logger.info("No weights found, starting completely new training")
+
+        self.total_epochs = self._read_info("epochs", 0)
+
+        self.model.fit(self.x_train, self.y_train,
+                       epochs=epochs,
+                       class_weight=self.class_weight,
+                       shuffle=True,
+                       batch_size=self.batch_size)
+
+        self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
+
+        self.total_epochs += epochs
+        self._write_info("epochs", self.total_epochs)
 
     def evaluate(self):
         pass
@@ -188,8 +285,4 @@ if __name__ == "__main__":
                                 weight_expr = "eventWeight*genWeight",
                                 identifiers = ["DatasetNumber", "EventNumber"])
 
-    c.load_data()
-    print(c.x_train)
-    print(len(c.x_train))
-
-    print(c.scaler)
+    c.train(epochs=1)