From e543ec8d30aa25453b1d8b9fd60fdc7a9e1593a0 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Mon, 13 Aug 2018 10:52:13 +0200
Subject: [PATCH] plot train and test for ROC by default

---
 toolkit.py | 34 +++++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 05534a2..cc47b00 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -12,7 +12,6 @@ else:
 
 import os
 import json
-import yaml
 import pickle
 import importlib
 import csv
@@ -980,27 +979,32 @@ class ClassificationProject(object):
         plt.close(fig)
 
 
-    def plot_ROC(self):
+    def plot_ROC(self, xlim=(0,1), ylim=(0,1)):
 
         logger.info("Plot ROC curve")
-        fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test)
-        fpr = 1.0 - fpr
-        try:
-            roc_auc = auc(tpr, fpr, reorder=True)
-        except ValueError:
-            logger.warning("Got a value error from auc - trying to rerun with reorder=True")
-            roc_auc = auc(tpr, fpr, reorder=True)
-
         plt.grid(color='gray', linestyle='--', linewidth=1)
-        plt.plot(tpr,  fpr, label=str(self.name + " (AUC = {})".format(roc_auc)))
+
+        for y, scores, weight, label in [
+                (self.y_train, self.scores_train, self.w_train, "train"),
+                (self.y_test, self.scores_test, self.w_test, "test")
+        ]:
+            fpr, tpr, threshold = roc_curve(y, scores, sample_weight = weight)
+            fpr = 1.0 - fpr # background rejection
+            try:
+                roc_auc = auc(tpr, fpr)
+            except ValueError:
+                logger.warning("Got a value error from auc - trying to rerun with reorder=True")
+                roc_auc = auc(tpr, fpr, reorder=True)
+            plt.plot(tpr,  fpr, label=str(self.name + " {} (AUC = {:.3f})".format(label, roc_auc)))
+
         plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
         plt.ylabel("Background rejection")
         plt.xlabel("Signal efficiency")
         plt.title('Receiver operating characteristic')
-        plt.xlim(0,1)
-        plt.ylim(0,1)
-        plt.xticks(np.arange(0,1,0.1))
-        plt.yticks(np.arange(0,1,0.1))
+        plt.xlim(*xlim)
+        plt.ylim(*ylim)
+        # plt.xticks(np.arange(0,1,0.1))
+        # plt.yticks(np.arange(0,1,0.1))
         plt.legend(loc='lower left', framealpha=1.0)
         plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
         plt.clf()
-- 
GitLab