From 7058afbf6c9b4a5491c296fa9c8303fdb329899b Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 6 Aug 2018 15:11:32 +0200 Subject: [PATCH] Allow reoder=True in auc if values are not increasing (due to neg event weights) --- compare.py | 6 +++++- toolkit.py | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/compare.py b/compare.py index 55ac03d..7e4c9a8 100755 --- a/compare.py +++ b/compare.py @@ -38,7 +38,11 @@ def overlay_ROC(filename, *projects, **kwargs): for p, color in zip(projects, colors): fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight = p.w_test) fpr = 1.0 - fpr - roc_auc = auc(tpr, fpr) + try: + roc_auc = auc(tpr, fpr) + except ValueError: + logger.warning("Got a value error from auc - trying to rerun with reorder=True") + roc_auc = auc(tpr, fpr, reorder=True) ax.grid(color='gray', linestyle='--', linewidth=1) ax.plot(tpr, fpr, label=str(p.name+" (AUC = {:.3f})".format(roc_auc)), color=color) diff --git a/toolkit.py b/toolkit.py index a379413..776ae2b 100755 --- a/toolkit.py +++ b/toolkit.py @@ -911,9 +911,12 @@ class ClassificationProject(object): logger.info("Plot ROC curve") fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test) - fpr = 1.0 - fpr - roc_auc = auc(tpr, fpr) + try: + roc_auc = auc(tpr, fpr, reorder=True) + except ValueError: + logger.warning("Got a value error from auc - trying to rerun with reorder=True") + roc_auc = auc(tpr, fpr, reorder=True) plt.grid(color='gray', linestyle='--', linewidth=1) plt.plot(tpr, fpr, label=str(self.name + " (AUC = {})".format(roc_auc))) -- GitLab