From 1fce12f7396ee135bcd11b83fac5bf217db678e0 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Mon, 20 Aug 2018 10:17:05 +0200
Subject: [PATCH] restructure plot_score function

---
 toolkit.py | 49 ++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 36 insertions(+), 13 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 52b911f..3896daf 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1097,19 +1097,36 @@ class ClassificationProject(object):
         plt.clf()
 
 
-    def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)), ylim=None, xlim=None):
-        centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
-        centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
-        centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts)
-        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts)
-        errors_sig_test = hist_sig_test*rel_errors_sig_test
-        errors_bkg_test = hist_bkg_test*rel_errors_bkg_test
+    def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)),
+                   ylim=None, xlim=None, density=True, lumifactor=None, apply_class_weight=True):
         fig, ax = plt.subplots()
-        width = centers_sig_train[1]-centers_sig_train[0]
-        ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
-        ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
-        ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test")
-        ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test")
+        for scores, weights, y, class_label, fn, opts in [
+                (self.scores_train, self.w_train, self.y_train, 1, ax.bar, dict(color="r", label="signal train")),
+                (self.scores_train, self.w_train, self.y_train, 0, ax.bar, dict(color="b", label="background train")),
+                (self.scores_test, self.w_test, self.y_test, 1, ax.errorbar, dict(fmt="ro", label="signal test")),
+                (self.scores_test, self.w_test, self.y_test, 0, ax.errorbar, dict(fmt="bo", label="background test")),
+        ]:
+            weights = weights[y==class_label]
+            if apply_class_weight is True and (lumifactor is not None):
+                logger.warning("not applying class weight, since lumifactor given")
+            if apply_class_weight and (lumifactor is None):
+                weights = weights*self.class_weight[class_label]
+            if lumifactor is not None:
+                weights = weights*lumifactor
+            centers, hist, rel_errors = self.get_bin_centered_hist(
+                scores[y==class_label].reshape(-1),
+                weights=weights,
+                **plot_opts
+            )
+            width = centers[1]-centers[0]
+            if density:
+                hist = hist/width
+            if fn == ax.errorbar:
+                errors = rel_errors*hist
+                opts.update(yerr=errors)
+            else:
+                opts.update(width=width, alpha=0.5)
+            fn(centers, hist, **opts)
         if log:
             ax.set_yscale("log")
         if ylim is not None:
@@ -1117,7 +1134,13 @@ class ClassificationProject(object):
         if xlim is not None:
             ax.set_xlim(*xlim)
         ax.set_xlabel("NN output")
-        fig.legend(loc='upper center', framealpha=0.5)
+        if density:
+            ax.set_ylabel("dN / d(NN output)")
+        else:
+            ax.set_ylabel("Events / {:.2f}".format(width))
+        if apply_class_weight:
+            ax.set_title("Class weights applied")
+        ax.legend(loc='upper center', framealpha=0.5)
         fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
         plt.close(fig)
 
-- 
GitLab