From e543ec8d30aa25453b1d8b9fd60fdc7a9e1593a0 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Mon, 13 Aug 2018 10:52:13 +0200 Subject: [PATCH] plot train and test for ROC by default --- toolkit.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/toolkit.py b/toolkit.py index 05534a2..cc47b00 100755 --- a/toolkit.py +++ b/toolkit.py @@ -12,7 +12,6 @@ else: import os import json -import yaml import pickle import importlib import csv @@ -980,27 +979,32 @@ class ClassificationProject(object): plt.close(fig) - def plot_ROC(self): + def plot_ROC(self, xlim=(0,1), ylim=(0,1)): logger.info("Plot ROC curve") - fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test) - fpr = 1.0 - fpr - try: - roc_auc = auc(tpr, fpr, reorder=True) - except ValueError: - logger.warning("Got a value error from auc - trying to rerun with reorder=True") - roc_auc = auc(tpr, fpr, reorder=True) - plt.grid(color='gray', linestyle='--', linewidth=1) - plt.plot(tpr, fpr, label=str(self.name + " (AUC = {})".format(roc_auc))) + + for y, scores, weight, label in [ + (self.y_train, self.scores_train, self.w_train, "train"), + (self.y_test, self.scores_test, self.w_test, "test") + ]: + fpr, tpr, threshold = roc_curve(y, scores, sample_weight = weight) + fpr = 1.0 - fpr # background rejection + try: + roc_auc = auc(tpr, fpr) + except ValueError: + logger.warning("Got a value error from auc - trying to rerun with reorder=True") + roc_auc = auc(tpr, fpr, reorder=True) + plt.plot(tpr, fpr, label=str(self.name + " {} (AUC = {:.3f})".format(label, roc_auc))) + plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck') plt.ylabel("Background rejection") plt.xlabel("Signal efficiency") plt.title('Receiver operating characteristic') - plt.xlim(0,1) - plt.ylim(0,1) - plt.xticks(np.arange(0,1,0.1)) - plt.yticks(np.arange(0,1,0.1)) + plt.xlim(*xlim) + plt.ylim(*ylim) + # plt.xticks(np.arange(0,1,0.1)) + # plt.yticks(np.arange(0,1,0.1)) plt.legend(loc='lower left', framealpha=1.0) plt.savefig(os.path.join(self.project_dir, "ROC.pdf")) plt.clf() -- GitLab