diff --git a/toolkit.py b/toolkit.py
index 3c658692a22cd1c9472a9d0af502090013858aac..96fb4c99a07ca1d2754f727f74e04b51ce992da6 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -230,6 +230,7 @@ class ClassificationProject(object):
 
         self._scaler = None
         self._class_weight = None
+        self._balanced_class_weight = None
         self._model = None
         self._history = None
         self._callbacks_list = []
@@ -534,10 +535,35 @@ class ClassificationProject(object):
             sumw_bkg = np.sum(self.w_train[self.y_train == 0])
             sumw_sig = np.sum(self.w_train[self.y_train == 1])
             self._class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
-        logger.debug("Calculated class_weight: {}".format(self._class_weight))
+            logger.debug("Calculated class_weight: {}".format(self._class_weight))
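+            # note: class_weight rescales by the *total* sum of weights per
+            # class; balanced_class_weight below uses the sum *per event*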
         return self._class_weight
 
 
+    @property
+    def balanced_class_weight(self):
+        """
+        Class weight for the balance_dataset method
+        Since we have equal number of signal and background events in
+        each batch, we need to balance the ratio of sum of weights per
+        event with class weights
+        """
+        if self._balanced_class_weight is None:
+            sumw_bkg = np.sum(self.w_train[self.y_train == 0])
+            sumw_sig = np.sum(self.w_train[self.y_train == 1])
+            # use sumw *per event* in this case
+            sumw_bkg /= len(self.w_train[self.y_train == 0])
+            sumw_sig /= len(self.w_train[self.y_train == 1])
+            self._balanced_class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
+            logger.debug("Calculated balanced_class_weight: {}".format(self._balanced_class_weight))
+        return self._balanced_class_weight
+
+
     def load(self, reload=False):
         "Load all data needed for plotting and training"
 
@@ -601,13 +627,22 @@ class ClassificationProject(object):
             for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
                 yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
                        y_train[y_train==class_label][start:start+int(self.batch_size/2)],
-                       w_train[y_train==class_label][start:start+int(self.batch_size/2)])
+                       w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label])
             # restart
 
 
     def yield_balanced_batch(self):
         "generate batches with equal amounts of both classes"
+        logcounter = 0
         for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)):
+            logcounter += 1
+            if logcounter >= 10:
+                logger.debug("Sumw sig*balanced_class_weight[1]: {}".format(np.sum(batch_1[2])))
+                logger.debug("Sumw bkg*balanced_class_weight[0]: {}".format(np.sum(batch_0[2])))
+                logcounter = 0
             yield (np.concatenate((batch_0[0], batch_1[0])),
                    np.concatenate((batch_0[1], batch_1[1])),
                    np.concatenate((batch_0[2], batch_1[2])))
@@ -645,6 +680,9 @@ class ClassificationProject(object):
                 self.is_training = True
                 labels, label_counts = np.unique(self.y_train, return_counts=True)
                 logger.info("Training on balanced batches")
+                # note: the batches have balanced_class_weight already applied
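+                # each generator step yields batch_size events in total
+                # (batch_size/2 from each class)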
                 self.model.fit_generator(self.yield_balanced_batch(),
                                          steps_per_epoch=int(min(label_counts)/self.batch_size),
                                          epochs=epochs,