diff --git a/toolkit.py b/toolkit.py
index 46a48856dd9a85ff7d6b0af862411f276e86cbef..a0be319112eca728362cb312e91d792b46fafa0f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -38,6 +38,7 @@ from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
+import keras.initializers
 import keras.optimizers
 from keras.utils.vis_utils import model_to_dot
 from keras import backend as K
@@ -188,7 +189,9 @@ class ClassificationProject(object):
 
     :param normalize_weights: normalize the weights to mean 1
 
-    :param ignore_neg_weights: ignore events with negative weights in training (default: False)
+    :param ignore_neg_weights: ignore events with negative weights in training (not recommended; default: False)
+
+    :param kernel_initializer: name of a weight initializer from keras.initializers for the dense layers; if None (default), the Keras default is used
 
     """
 
@@ -260,7 +263,8 @@ class ClassificationProject(object):
                         mask_value=None,
                         apply_class_weight=True,
                         normalize_weights=True,
-                        ignore_neg_weights=False):
+                        ignore_neg_weights=False,
+                        kernel_initializer=None):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -343,6 +347,7 @@ class ClassificationProject(object):
         self.apply_class_weight = apply_class_weight
         self.normalize_weights = normalize_weights
         self.ignore_neg_weights = ignore_neg_weights
+        self.kernel_initializer = kernel_initializer
 
         self.s_train = None
         self.b_train = None
@@ -734,7 +739,11 @@ class ClassificationProject(object):
                     self.dropout,
                     self.use_bias,
             ):
-                hidden_layer = Dense(node_count, activation=self.activation_function, use_bias=use_bias)(hidden_layer)
+                extra_opts = dict()
+                if self.kernel_initializer is not None:
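+                    # look up the initializer by name in keras.initializers and instantiate it with default arguments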
+                    extra_opts["kernel_initializer"] = getattr(keras.initializers, self.kernel_initializer)()
+                hidden_layer = Dense(node_count, activation=self.activation_function, use_bias=use_bias, **extra_opts)(hidden_layer)
                 if (dropout_fraction is not None) and (dropout_fraction > 0):
                     hidden_layer = Dropout(rate=dropout_fraction)(hidden_layer)
 
@@ -1786,7 +1795,11 @@ class ClassificationProjectRNN(ClassificationProject):
                 flat_channel = flat_input
             combined = concatenate(rnn_channels+[flat_channel])
             for node_count, dropout_fraction in zip(self.nodes, self.dropout):
-                combined = Dense(node_count, activation=self.activation_function)(combined)
+                extra_opts = dict()
+                if self.kernel_initializer is not None:
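+                    # look up the initializer by name in keras.initializers and instantiate it with default arguments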
+                    extra_opts["kernel_initializer"] = getattr(keras.initializers, self.kernel_initializer)()
+                combined = Dense(node_count, activation=self.activation_function, **extra_opts)(combined)
                 if (dropout_fraction is not None) and (dropout_fraction > 0):
                     combined = Dropout(rate=dropout_fraction)(combined)
             combined = Dense(1, activation=self.activation_function_output)(combined)
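
For reference, a standalone sketch (outside the patch) of the name-to-initializer resolution this change introduces. `dense_with_initializer` is a hypothetical helper written for illustration, and `he_normal` is just one example of an attribute exposed by `keras.initializers` in Keras 2:

```python
# Self-contained sketch of the resolution pattern used in the patch,
# assuming a Keras 2 installation where keras.initializers exposes
# initializers as module attributes (e.g. he_normal, glorot_uniform).
import keras.initializers
from keras.layers import Dense

def dense_with_initializer(node_count, activation, kernel_initializer=None):
    # Collect optional kwargs so that None falls back to the Keras
    # default, mirroring the extra_opts pattern in the patch above.
    extra_opts = dict()
    if kernel_initializer is not None:
        # getattr raises AttributeError for unknown names, surfacing
        # typos early; the trailing () instantiates the initializer
        # with default arguments.
        extra_opts["kernel_initializer"] = getattr(keras.initializers, kernel_initializer)()
    return Dense(node_count, activation=activation, **extra_opts)

# Example: a 64-unit relu layer whose kernel is initialized with he_normal
layer = dense_with_initializer(64, "relu", kernel_initializer="he_normal")
```

Note that because the initializer is instantiated with no arguments, only zero-argument construction is supported; initializers that need constructor parameters (for example a fixed seed) cannot be configured through this option.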