From 2bffb3024a996a61c6a694072a6ad98689a213c1 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Sun, 26 Aug 2018 16:56:25 +0200
Subject: [PATCH] lumifactor support for overlay_ROC

---
 compare.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/compare.py b/compare.py
index 9ebcb65..2681dca 100755
--- a/compare.py
+++ b/compare.py
@@ -20,6 +20,7 @@ def overlay_ROC(filename, *projects, **kwargs):
     ylim = kwargs.pop("ylim", (0,1))
     plot_thresholds = kwargs.pop("plot_thresholds", False)
     threshold_log = kwargs.pop("threshold_log", True)
+    lumifactor = kwargs.pop("lumifactor", None)
     if kwargs:
         raise KeyError("Unknown kwargs: {}".format(kwargs))
 
@@ -33,6 +34,10 @@ def overlay_ROC(filename, *projects, **kwargs):
         if threshold_log:
             ax2.set_yscale("log")
 
+    if lumifactor is not None:
+        ax_abs_b = ax.twinx()
+        ax_abs_s = ax.twiny()
+
     prop_cycle = plt.rcParams['axes.prop_cycle']
     colors = prop_cycle.by_key()['color']
 
@@ -49,15 +54,29 @@ def overlay_ROC(filename, *projects, **kwargs):
         ax.plot(tpr,  fpr, label=str(p.name+" (AUC = {:.3f})".format(roc_auc)), color=color)
         if plot_thresholds:
             ax2.plot(tpr, threshold, "--", color=color)
+        if lumifactor is not None:
+            sumw_b = p.w_test[p.y_test==0].sum()*lumifactor
+            sumw_s = p.w_test[p.y_test==1].sum()*lumifactor
+            ax_abs_b.plot(tpr, (1.-fpr)*sumw_b, "r--", alpha=0.5)
+            ax_abs_b.invert_yaxis()
+            ax_abs_s.plot(tpr*sumw_s, fpr, "g--", alpha=0.5)
 
     if xlim is not None:
         ax.set_xlim(*xlim)
     if ylim is not None:
         ax.set_ylim(*ylim)
+    if lumifactor is not None:
+        ax_abs_b.set_ylim((1-ax.get_ylim()[0])*sumw_b, (1-ax.get_ylim()[1])*sumw_b)
+        ax_abs_b.set_xlim(*ax.get_xlim())
+        ax_abs_s.set_xlim(ax.get_xlim()[0]*sumw_s, ax.get_xlim()[1]*sumw_s)
+        ax_abs_s.set_ylim(*ax.get_ylim())
+        ax_abs_b.set_ylabel("Number of background events")
+        ax_abs_s.set_xlabel("Number of signal events")
     # plt.xticks(np.arange(0,1,0.1))
     # plt.yticks(np.arange(0,1,0.1))
     ax.legend(loc='lower left', framealpha=1.0)
-    ax.set_title('Receiver operating characteristic')
+    if lumifactor is None:
+        ax.set_title('Receiver operating characteristic')
     ax.set_ylabel("Background rejection")
     ax.set_xlabel("Signal efficiency")
     if plot_thresholds:
-- 
GitLab