diff --git a/toolkit.py b/toolkit.py
index 61fd512d71e36d9f8469d792052d435c2f509423..947968c6a6d250db87d27511f125cad583198915 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -2062,11 +2062,12 @@ class ClassificationProjectDecorr(ClassificationProject):
 
     @property
     def class_weight_target(self):
+        "to weight the adversarial target to have equal sum of weights per bin"
         if self._class_weight_target is None:
             self._class_weight_target = []
             for var_i, binning in enumerate(self.decorr_binnings, 1):
                 sumw = self.w_train[self.l_train==0].sum()
-                class_weight = [
+                class_weight = np.array([
                     sumw/(
                     len(binning)
                     * self.w_train[
@@ -2075,7 +2076,7 @@ class ClassificationProjectDecorr(ClassificationProject):
                     ].sum()
                     )
                     for label in range(len(binning))
-                ]
+                ])
                 self._class_weight_target.append(class_weight)
         return self._class_weight_target
 
@@ -2099,6 +2100,11 @@ class ClassificationProjectDecorr(ClassificationProject):
             w_list[i] = np.array(w_list[i])
             # set signal weights to 0 for decorr target
             w_list[i][y[:,0]==1] = 0.
+            bin_labels = np.argmax(
+                self.get_output_list(y[y[:,0]==0])[1],
+                axis=1
+            )
+            w_list[i][y[:,0]==0] *= self.class_weight_target[i-1][bin_labels]
         return w_list