diff --git a/toolkit.py b/toolkit.py
index 5280d07aedec891605415b0283fd12882bbc1017..09fcacd840250896811b076ee5fdce32f3cb05fc 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -236,7 +236,8 @@ class ClassificationProject(object):
                         random_seed=1234,
                         balance_dataset=False,
                         loss='binary_crossentropy',
-                        mask_value=None):
+                        mask_value=None,
+                        apply_class_weight=True):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -308,6 +309,7 @@ class ClassificationProject(object):
         self.balance_dataset = balance_dataset
         self.loss = loss
         self.mask_value = mask_value
+        self.apply_class_weight = apply_class_weight
 
         self.s_train = None
         self.b_train = None
@@ -804,7 +806,10 @@ class ClassificationProject(object):
         else:
             class_weight = self.balanced_class_weight
         if self._w_train_tot is None:
-            self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
+            if self.apply_class_weight:
+                self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
+            else:
+                self._w_train_tot = np.array(self.w_train)
             self._w_train_tot /= np.mean(self._w_train_tot)
         return self._w_train_tot