From 80673f0e90bddcfd366bb50f19459e1a66b8f427 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Wed, 28 Nov 2018 10:48:59 +0100
Subject: [PATCH] starting adversarial setup for decorrelation

---
 toolkit.py | 39 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/toolkit.py b/toolkit.py
index d673121..89e3e89 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"]
+__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN", "ClassificationProjectDecorr"]
 
 from sys import version_info
 
@@ -72,6 +72,8 @@ def load_from_dir(path):
         project_type = info["project_type"]
         if project_type == "ClassificationProjectRNN":
             return ClassificationProjectRNN(path)
+        elif project_type == "ClassificationProjectDecorr":
+            return ClassificationProjectDecorr(path)
     except (KeyError, IOError):
         pass
     return ClassificationProject(path)
@@ -2014,6 +2016,41 @@ class ClassificationProjectRNN(ClassificationProject):
         return self.predict(self.get_input_list(x_eval), mode=mode)
 
 
+class ClassificationProjectDecorr(ClassificationProject):
+
+    def __init__(self, *args, **kwargs):
+        super(ClassificationProjectDecorr, self).__init__(*args, **kwargs)
+        self.decorr_binnings = []
+        self.decorr_bins = 10
+        self._write_info("project_type", "ClassificationProjectDecorr")
+
+
+    def load(self, *args, **kwargs):
+        super(ClassificationProjectDecorr, self).load(*args, **kwargs)
+        bin_frac = 1./float(self.decorr_bins)
+        for idx, field_name in enumerate(self.target_fields):
+            # adversary target is fit as multiclass problem with bin indices
+            # (self.decorr_bins quantiles) as labels like in arXiv:1703.03507
+            self.decorr_binnings.append(
+                weighted_quantile(
+                    self.y_train[self.l_train==0][:,idx], # bkg only
+                    np.arange(bin_frac, 1+bin_frac, bin_frac),
+                    sample_weight=self.w_train[self.l_train==0]
+                )
+            )
+
+
+    def get_output_list(self, y):
+        out_list = super(ClassificationProjectDecorr, self).get_output_list(y)
+        for i, (out, binning) in enumerate(
+                zip(out_list[1:], self.decorr_binnings)
+        ):
+            bin_idx = np.digitize(out, binning)
+            # include overflow into last bin
+            bin_idx[bin_idx==len(deciles)] = len(deciles)-1
+            out_list[i] = keras.utils.to_categorical(bin_idx)
+
+
 if __name__ == "__main__":
 
     logging.basicConfig()
-- 
GitLab