From 53ba3604e33b7b94aa603833d0e4786bec62616a Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Fri, 30 Nov 2018 15:09:20 +0100
Subject: [PATCH] apply class weights for adversarial targets

---
 toolkit.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 61fd512..947968c 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -2062,11 +2062,12 @@ class ClassificationProjectDecorr(ClassificationProject):
 
     @property
     def class_weight_target(self):
+        "to weight the adversarial target to have equal sum of weights per bin"
         if self._class_weight_target is None:
             self._class_weight_target = []
             for var_i, binning in enumerate(self.decorr_binnings, 1):
                 sumw = self.w_train[self.l_train==0].sum()
-                class_weight = [
+                class_weight = np.array([
                     sumw/(
                     len(binning)
                     * self.w_train[
@@ -2075,7 +2076,7 @@ class ClassificationProjectDecorr(ClassificationProject):
                     ].sum()
                     )
                     for label in range(len(binning))
-                ]
+                ])
                 self._class_weight_target.append(class_weight)
         return self._class_weight_target
 
@@ -2099,6 +2100,11 @@ class ClassificationProjectDecorr(ClassificationProject):
             w_list[i] = np.array(w_list[i])
             # set signal weights to 0 for decorr target
             w_list[i][y[:,0]==1] = 0.
+            bin_labels = np.argmax(
+                self.get_output_list(y[y[:,0]==0])[1],
+                axis=1
+            )
+            w_list[i][y[:,0]==0] *= self.class_weight_target[i-1][bin_labels]
         return w_list
 
 
-- 
GitLab