From f205e49d422279bcccc88fb642c52d9e2dc4c1ab Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 8 May 2018 17:42:14 +0200
Subject: [PATCH] scale input-plot histograms by class weight instead of
 multiplying it into the sample weights

---
 toolkit.py | 51 +++++++++++++++++----------------------------------
 1 file changed, 17 insertions(+), 34 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index ee64a09..205b3ca 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -198,8 +198,6 @@ class KerasROOTClassification(object):
 
         self._scaler = None
         self._class_weight = None
-        self._bkg_weights = None
-        self._sig_weights = None
         self._model = None
         self._history = None
         self._callbacks_list = []
@@ -586,34 +584,12 @@ class KerasROOTClassification(object):
         pass
 
 
-    @property
-    def bkg_weights(self):
-        """
-        class weights multiplied by event weights (for plotting)
-        TODO: find a better way to do this
-        """
-        if self._bkg_weights is None:
-            logger.debug("Calculating background weights for plotting")
-            self._bkg_weights = np.empty(sum(self.y_train == 0))
-            self._bkg_weights.fill(self.class_weight[0])
-            self._bkg_weights *= self.w_train[self.y_train == 0]
-            logger.debug("Background weights: {}".format(self._bkg_weights))
-        return self._bkg_weights
-
-
-    @property
-    def sig_weights(self):
-        """
-        class weights multiplied by event weights (for plotting)
-        TODO: find a better way to do this
-        """
-        if self._sig_weights is None:
-            logger.debug("Calculating signal weights for plotting")
-            self._sig_weights = np.empty(sum(self.y_train == 1))
-            self._sig_weights.fill(self.class_weight[1])
-            self._sig_weights *= self.w_train[self.y_train == 1]
-            logger.debug("Signal weights: {}".format(self._sig_weights))
-        return self._sig_weights
+    @staticmethod
+    def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
+        """Return (bin centers, bin contents) of np.histogram(x, **np_kwargs); contents are multiplied by scale_factor unless it is None."""
+        hist, bins = np.histogram(x, **np_kwargs)
+        centers = (bins[:-1] + bins[1:]) / 2  # midpoint of each pair of adjacent bin edges
+        return centers, (hist if scale_factor is None else hist * scale_factor)  # out-of-place: integer counts may need promotion to float
 
 
     def plot_input(self, var_index):
@@ -622,6 +598,8 @@ class KerasROOTClassification(object):
         fig, ax = plt.subplots()
         bkg = self.x_train[:,var_index][self.y_train == 0]
         sig = self.x_train[:,var_index][self.y_train == 1]
+        bkg_weights = self.w_train[self.y_train == 0]
+        sig_weights = self.w_train[self.y_train == 1]
 
         logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
         logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
@@ -635,14 +613,19 @@ class KerasROOTClassification(object):
         logger.debug("Calculated range based on percentiles: {}".format(plot_range))
 
         try:
-            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
-            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
+            centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
         except ValueError:
             # weird, probably not always working workaround for a numpy bug
             plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
             logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
-            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
-            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
+            centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
+
+        width = centers_sig[1]-centers_sig[0]
+        ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
+        ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
+
         ax.set_xlabel(branch+" (transformed)")
         plot_dir = os.path.join(self.project_dir, "plots")
         if not os.path.exists(plot_dir):
-- 
GitLab