From db706fce036d0a568b7974589b95c3397dc199f0 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Wed, 28 Nov 2018 13:21:28 +0100
Subject: [PATCH] Pass targets to weight lists and add decorrelation layer properties

---
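
The adversary targets below are encoded as one-hot bin indices over
background-only quantiles of each regression target (see the comment in
load() and arXiv:1703.03507). A minimal numpy sketch of what the
get_output_list override effectively does for a single target column;
the sample data and names are illustrative stand-ins, not taken from
toolkit.py:

    import numpy as np

    n_bins = 10                  # plays the role of self.decorr_bins
    bin_frac = 1. / n_bins
    target = np.random.default_rng(0).normal(size=1000)  # one target column

    # quantile bin edges, analogous to weighted_quantile(...) in load()
    binning = np.quantile(target, np.arange(bin_frac, 1 + bin_frac, bin_frac))

    # bin index per event; fold the overflow bin into the last one,
    # as done in get_output_list
    bin_idx = np.digitize(target, binning)
    bin_idx[bin_idx == len(binning)] = len(binning) - 1

    # one-hot labels for the softmax adversary output
    # (numpy stand-in for keras.utils.to_categorical)
    labels = np.eye(len(binning))[bin_idx]

Each adversary output then has len(binning) softmax nodes, which is what
Dense(len(binning), activation="softmax") in adv_layers provides.
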
 toolkit.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 78 insertions(+), 7 deletions(-)
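
Passing y into get_weight_list lets the decorrelation project zero the
signal weights for the adversary outputs, so the adversary is only
trained on background. A small sketch of the resulting weight-list
handling with toy arrays (the values and the number of outputs are made
up for illustration):

    import numpy as np

    w = np.array([1.0, 2.0, 0.5, 1.5])  # per-event weights
    y = np.array([[1], [0], [1], [0]])  # y[:, 0] = 1 for signal, 0 for background
    n_outputs = 3                       # classifier output plus two adversary targets

    # base behaviour: the same weight array repeated once per model output
    w_list = [w] * n_outputs

    # decorrelation override: decouple the classifier entry first, because
    # the remaining entries reference the same array, then zero the signal
    # weights for every adversary output
    w_list[0] = np.array(w_list[0])
    for w_adv in w_list[1:]:
        w_adv[y[:, 0] == 1] = 0.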

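The class_layers and adv_layers properties below only collect layer
objects; the class_input and model properties are still stubs in this
patch. One possible way to chain such layer lists with the Keras
functional API is sketched here, purely for illustration: the layer
sizes, the import paths and the choice to feed the classifier output
into the adversary are assumptions, not something this patch implements.

    from keras.layers import Input, Dense
    from keras.models import Model

    inp = Input((5,))                               # Input((len(self.fields),)) in class_layers
    class_layers = [Dense(64, activation="relu"),
                    Dense(1, activation="sigmoid")]
    adv_hidden = Dense(128, activation="tanh")      # as in adv_layers
    adv_targets = [Dense(10, activation="softmax"), # one softmax head per decorr binning
                   Dense(10, activation="softmax")]

    # chain the classifier layers
    x = inp
    for layer in class_layers:
        x = layer(x)

    # feed the classifier output into the adversary (assumption, see above)
    h = adv_hidden(x)
    adv_outs = [head(h) for head in adv_targets]

    combined = Model(inputs=inp, outputs=[x] + adv_outs)

A combined model like this is what the per-output target and weight
lists above are shaped for: one entry for the classifier plus one per
adversary target.
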
diff --git a/toolkit.py b/toolkit.py
index 89e3e89..6098d68 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -936,7 +936,7 @@ class ClassificationProject(object):
         x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
         x_val_input = self.get_input_list(self.transform(x_val))
         y_val_output = self.get_output_list(self.transform_target(y_val))
-        w_val_list = self.get_weight_list(w_val)
+        w_val_list = self.get_weight_list(w_val, y_val)
         return x_val_input, y_val_output, w_val_list
 
 
@@ -947,7 +947,7 @@ class ClassificationProject(object):
         x_train, y_train, w_train =  self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
         x_train_input = self.get_input_list(self.transform(x_train))
         y_train_output = self.get_output_list(self.transform_target(y_train))
-        w_train_list = self.get_weight_list(w_train)
+        w_train_list = self.get_weight_list(w_train, y_train)
         return x_train_input, y_train_output, w_train_list
 
 
@@ -991,7 +991,7 @@ class ClassificationProject(object):
             return np.hsplit(y, len(self.target_fields)+1)
 
 
-    def get_weight_list(self, w):
+    def get_weight_list(self, w, y=None):
         "Repeat weight n times for regression targets"
         if not self.target_fields:
             return w
@@ -1018,7 +1018,7 @@ class ClassificationProject(object):
                 w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
                 x_input = self.get_input_list(self.transform(x_batch))
                 y_output = self.get_output_list(self.transform_target(y_batch))
-                w_list = self.get_weight_list(w_batch)
+                w_list = self.get_weight_list(w_batch, y_batch)
                 yield (x_input, y_output, w_list)
 
 
@@ -2023,17 +2023,20 @@ class ClassificationProjectDecorr(ClassificationProject):
         self.decorr_binnings = []
         self.decorr_bins = 10
         self._write_info("project_type", "ClassificationProjectDecorr")
+        self._class_layers = None
+        self._adv_hidden_layers = None
+        self._adv_target_layers = None
 
 
     def load(self, *args, **kwargs):
         super(ClassificationProjectDecorr, self).load(*args, **kwargs)
         bin_frac = 1./float(self.decorr_bins)
         for idx, field_name in enumerate(self.target_fields):
             # adversary target is fit as multiclass problem with bin indices
             # (self.decorr_bins quantiles) as labels like in arXiv:1703.03507
             self.decorr_binnings.append(
                 weighted_quantile(
-                    self.y_train[self.l_train==0][:,idx], # bkg only
+                    self.y_train[self.l_train==0][:,idx+1], # bkg only
                     np.arange(bin_frac, 1+bin_frac, bin_frac),
                     sample_weight=self.w_train[self.l_train==0]
                 )
@@ -2047,8 +2050,76 @@
         ):
             bin_idx = np.digitize(out, binning)
             # include overflow into last bin
-            bin_idx[bin_idx==len(deciles)] = len(deciles)-1
-            out_list[i] = keras.utils.to_categorical(bin_idx)
+            bin_idx[bin_idx==len(binning)] = len(binning)-1
+            out_list[i+1] = keras.utils.to_categorical(bin_idx)
+        return out_list
+
+
+    def get_weight_list(self, w, y):
+        w_list = super(ClassificationProjectDecorr, self).get_weight_list(w)
+        # copy the first entry (the remaining entries may reference the same array)
+        w_list[0] = np.array(w_list[0])
+        for w_adv in w_list[1:]:
+            # set signal weights to 0 for the decorr (adversary) targets
+            w_adv[y[:,0]==1] = 0.
+        return w_list
+
+
+    @property
+    def class_layers(self):
+        """
+        Layers for the classification model
+
+        This should be generalised to avoid code duplication with the
+        model-building functions of the base classes.
+        """
+        if self._class_layers is None:
+            layers = []
+            self._class_layers = layers
+            layers.append(Input((len(self.fields),)))
+            if self.dropout_input is not None:
+                layers.append(Dropout(rate=self.dropout_input))
+            for node_count, dropout_fraction, use_bias in zip(
+                    self.nodes,
+                    self.dropout,
+                    self.use_bias,
+            ):
+                layers.append(
+                    Dense(
+                        node_count,
+                        activation=self.activation_function,
+                        use_bias=use_bias
+                    )
+                )
+                if (dropout_fraction is not None) and (dropout_fraction > 0):
+                    layers.append(Dropout(rate=dropout_fraction))
+            layers.append(Dense(1, activation=self.activation_function_output))
+        return self._class_layers
+
+
+    @property
+    def adv_layers(self):
+        """
+        Layers for the adversary
+        """
+        if self._adv_hidden_layers is None:
+            self._adv_hidden_layers = []
+            self._adv_target_layers = []
+            self._adv_hidden_layers.append(Dense(128, activation="tanh"))
+            for binning in self.decorr_binnings:
+                layer = Dense(len(binning), activation="softmax")
+                self._adv_target_layers.append(layer)
+        return self._adv_hidden_layers+self._adv_target_layers
+
+
+    @property
+    def class_input(self):
+        pass  # placeholder, returns None for now
+
+
+    @property
+    def model(self):
+        pass  # placeholder, returns None for now
 
 
 if __name__ == "__main__":
-- 
GitLab