diff --git a/toolkit.py b/toolkit.py index 05534a2c66836acb048ed47879d0f1d3bafc8ca0..cc47b00f661f21d718b7965521755be000221123 100755 --- a/toolkit.py +++ b/toolkit.py @@ -12,7 +12,6 @@ else: import os import json -import yaml import pickle import importlib import csv @@ -980,27 +979,32 @@ class ClassificationProject(object): plt.close(fig) - def plot_ROC(self): + def plot_ROC(self, xlim=(0,1), ylim=(0,1)): logger.info("Plot ROC curve") - fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test) - fpr = 1.0 - fpr - try: - roc_auc = auc(tpr, fpr, reorder=True) - except ValueError: - logger.warning("Got a value error from auc - trying to rerun with reorder=True") - roc_auc = auc(tpr, fpr, reorder=True) - plt.grid(color='gray', linestyle='--', linewidth=1) - plt.plot(tpr, fpr, label=str(self.name + " (AUC = {})".format(roc_auc))) + + for y, scores, weight, label in [ + (self.y_train, self.scores_train, self.w_train, "train"), + (self.y_test, self.scores_test, self.w_test, "test") + ]: + fpr, tpr, threshold = roc_curve(y, scores, sample_weight = weight) + fpr = 1.0 - fpr # background rejection + try: + roc_auc = auc(tpr, fpr) + except ValueError: + logger.warning("Got a value error from auc - trying to rerun with reorder=True") + roc_auc = auc(tpr, fpr, reorder=True) + plt.plot(tpr, fpr, label=str(self.name + " {} (AUC = {:.3f})".format(label, roc_auc))) + plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck') plt.ylabel("Background rejection") plt.xlabel("Signal efficiency") plt.title('Receiver operating characteristic') - plt.xlim(0,1) - plt.ylim(0,1) - plt.xticks(np.arange(0,1,0.1)) - plt.yticks(np.arange(0,1,0.1)) + plt.xlim(*xlim) + plt.ylim(*ylim) + # plt.xticks(np.arange(0,1,0.1)) + # plt.yticks(np.arange(0,1,0.1)) plt.legend(loc='lower left', framealpha=1.0) plt.savefig(os.path.join(self.project_dir, "ROC.pdf")) plt.clf()