From 2e2344e75ad9225d4ea49af7b9cb38d528851b96 Mon Sep 17 00:00:00 2001
From: Eric Schanet <eric.schanet@cern.ch>
Date: Fri, 27 Apr 2018 17:41:21 +0200
Subject: [PATCH] Fixing ROC AUC computation

---
 toolkit.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index f4451f3..9dfe91f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -13,7 +13,7 @@ import pandas as pd
 import h5py
 from sklearn.preprocessing import StandardScaler, RobustScaler
 from sklearn.externals import joblib
-from sklearn.metrics import roc_curve
+from sklearn.metrics import roc_curve, auc
 
 from keras.models import Sequential
 from keras.layers import Dense
@@ -128,7 +128,7 @@ class KerasROOTClassification:
             self.b_train = tree2array(bkg_chain,
                                       branches=self.branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
-                                      start=0, step=2)
+                                      start=0, step=200)
             self.s_test = tree2array(signal_chain,
                                      branches=self.branches+[self.weight_expr],
                                      selection=self.selection,
@@ -136,7 +136,7 @@ class KerasROOTClassification:
             self.b_test = tree2array(bkg_chain,
                                      branches=self.branches+[self.weight_expr],
                                      selection=self.selection,
-                                     start=1, step=2)
+                                     start=1, step=200)
 
             self._dump_training_list()
             self.s_eventlist_train = self.s_train[self.identifiers]
@@ -404,13 +404,15 @@ class KerasROOTClassification:
 
         logger.info("Plot ROC curve")
         fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test)
+
         fpr = 1.0 - fpr
+        roc_auc = auc(tpr, fpr)  # area under (1 - FPR) as a function of TPR
 
         plt.grid(color='gray', linestyle='--', linewidth=1)
-        plt.plot(fpr, tpr, label='NN')
+        plt.plot(tpr, fpr, label='NN')  # x: TPR, y: 1 - FPR (fpr was inverted above)
         plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
-        plt.xlabel("False positive rate (background rejection)")
-        plt.ylabel("True positive rate (signal efficiency)")
+        plt.ylabel("Background rejection")  # y axis carries 1 - FPR
+        plt.xlabel("Signal efficiency")  # x axis carries TPR
         plt.title('Receiver operating characteristic')
         plt.xlim(0,1)
         plt.ylim(0,1)
@@ -421,7 +423,6 @@ class KerasROOTClassification:
         plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
         plt.clf()
 
-
     def plot_score(self):
         pass
 
-- 
GitLab