diff --git a/toolkit.py b/toolkit.py
index 70aa273bfc095469bd66aeb9d1f4d78e4b18d4f7..3d862f5d27f7f36062635a5d906bff5e15bb0a07 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -94,6 +94,10 @@ class KerasROOTClassification(object):
 
     :param earlystopping_opts: options for the keras EarlyStopping callback
 
+    :param random_seed: use this seed value when initialising the model and produce consistent results. Note:
+                        random data is also used for shuffling the training data, so results may vary still. To
+                        produce consistent results, set the numpy random seed before training.
+
     """
 
 
@@ -134,7 +138,8 @@ class KerasROOTClassification(object):
                         step_bkg=2,
                         optimizer="SGD",
                         optimizer_opts=None,
-                        earlystopping_opts=None):
+                        earlystopping_opts=None,
+                        random_seed=1234):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -168,6 +173,8 @@ class KerasROOTClassification(object):
         if not os.path.exists(self.project_dir):
             os.mkdir(self.project_dir)
 
+        self.random_seed = random_seed
+
         self.s_train = None
         self.b_train = None
         self.s_test = None
@@ -434,10 +441,12 @@ class KerasROOTClassification(object):
             Optimizer = getattr(keras.optimizers, self.optimizer)
             optimizer = Optimizer(**self.optimizer_opts)
             logger.info("Compile model")
+            rn_state = np.random.get_state()
+            np.random.seed(self.random_seed)
             self._model.compile(optimizer=optimizer,
                                 loss='binary_crossentropy',
                                 metrics=['accuracy'])
-
+            np.random.set_state(rn_state)
             try:
                 self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
                 logger.info("Found and loaded previously trained weights")
@@ -495,7 +504,6 @@ class KerasROOTClassification(object):
         try:
             self.history = History()
             self.shuffle_training_data()
-            
             self.model.fit(self.x_train,
                            # the reshape might be unnescessary here
                            self.y_train.reshape(-1, 1),
@@ -732,27 +740,26 @@ if __name__ == "__main__":
                                              (filename, "wjets_Sherpa221_NoSys")
                                 ],
                                 optimizer="Adam",
+                                #optimizer="SGD",
                                 #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
                                 earlystopping_opts=dict(monitor='val_loss',
                                     min_delta=0, patience=2, verbose=0, mode='auto'),
-                                # optimizer="Adam",
                                 selection="lep1Pt<5000", # cut out a few very weird outliers
                                 branches = ["met", "mt"],
                                 weight_expr = "eventWeight*genWeight",
                                 identifiers = ["DatasetNumber", "EventNumber"],
                                 step_bkg = 100)
 
+    np.random.seed(42)
     c.train(epochs=20)
     c.plot_ROC()
     c.plot_loss()
     c.plot_accuracy()
 
-    c.write_friend_tree("test4_score",
-                        source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
-                        target_filename="friend.root", target_treename="test4_score")
-
-    np.random.seed(1234)
+    # c.write_friend_tree("test4_score",
+    #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
+    #                     target_filename="friend.root", target_treename="test4_score")
 
-    c.write_friend_tree("test4_score",
-                        source_filename=filename, source_treename="ttbar_NoSys",
-                        target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")
+    # c.write_friend_tree("test4_score",
+    #                     source_filename=filename, source_treename="ttbar_NoSys",
+    #                     target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")