From 38978a4970c04d3487eba373a86bbd3e08a5869c Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 28 Aug 2018 15:32:19 +0200
Subject: [PATCH] option to turn off weight normalisation

---
 toolkit.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 50cfa16..bb47fb2 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -173,6 +173,10 @@ class ClassificationProject(object):
 
     :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
 
+    :param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background)
+
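+    :param normalize_weights: divide the total training weights (after the optional class weight) by their mean, so that they have mean 1; if False, this division is skipped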
+
     """
 
 
@@ -237,7 +241,8 @@ class ClassificationProject(object):
                         balance_dataset=False,
                         loss='binary_crossentropy',
                         mask_value=None,
-                        apply_class_weight=True):
+                        apply_class_weight=True,
+                        normalize_weights=True):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -310,6 +315,7 @@ class ClassificationProject(object):
         self.loss = loss
         self.mask_value = mask_value
         self.apply_class_weight = apply_class_weight
+        self.normalize_weights = normalize_weights
 
         self.s_train = None
         self.b_train = None
@@ -810,7 +816,8 @@ class ClassificationProject(object):
                 self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
             else:
                 self._w_train_tot = np.array(self.w_train)
-            self._w_train_tot /= np.mean(self._w_train_tot)
+            if self.normalize_weights:
+                self._w_train_tot /= np.mean(self._w_train_tot)
         return self._w_train_tot
 
 
-- 
GitLab