diff --git a/toolkit.py b/toolkit.py
index 50cfa16fd9ff1f434429fd4f61431e484d4e1c22..bb47fb29baf02c8339d4f93a9f34e0c45e12afae 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -173,6 +173,10 @@ class ClassificationProject(object):
 
     :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
 
+    :param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background)
+
+    :param normalize_weights: normalize the weights to mean 1
+
     """
 
 
@@ -237,7 +241,8 @@ class ClassificationProject(object):
                         balance_dataset=False,
                         loss='binary_crossentropy',
                         mask_value=None,
-                        apply_class_weight=True):
+                        apply_class_weight=True,
+                        normalize_weights=True):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -310,6 +315,7 @@ class ClassificationProject(object):
         self.loss = loss
         self.mask_value = mask_value
         self.apply_class_weight = apply_class_weight
+        self.normalize_weights = normalize_weights
 
         self.s_train = None
         self.b_train = None
@@ -810,7 +816,8 @@ class ClassificationProject(object):
                 self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
             else:
                 self._w_train_tot = np.array(self.w_train)
-            self._w_train_tot /= np.mean(self._w_train_tot)
+            if self.normalize_weights:
+                self._w_train_tot /= np.mean(self._w_train_tot)
         return self._w_train_tot