diff --git a/compare.py b/compare.py index 4a266c0cd555fbb74749069abe49cab38affd2f7..55ac03dc952746107f00eb9ec26863b8aad05ead 100755 --- a/compare.py +++ b/compare.py @@ -17,31 +17,49 @@ def overlay_ROC(filename, *projects, **kwargs): xlim = kwargs.pop("xlim", (0,1)) ylim = kwargs.pop("ylim", (0,1)) + plot_thresholds = kwargs.pop("plot_thresholds", False) + threshold_log = kwargs.pop("threshold_log", True) if kwargs: raise KeyError("Unknown kwargs: {}".format(kwargs)) logger.info("Overlay ROC curves for {}".format([p.name for p in projects])) - for p in projects: + fig, ax = plt.subplots() + + if plot_thresholds: + ax2 = ax.twinx() + ax2.set_ylabel("Thresholds") + if threshold_log: + ax2.set_yscale("log") + + prop_cycle = plt.rcParams['axes.prop_cycle'] + colors = prop_cycle.by_key()['color'] + + for p, color in zip(projects, colors): fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight = p.w_test) fpr = 1.0 - fpr roc_auc = auc(tpr, fpr) - plt.grid(color='gray', linestyle='--', linewidth=1) - plt.plot(tpr, fpr, label=str(p.name+" (AUC = {})".format(roc_auc))) + ax.grid(color='gray', linestyle='--', linewidth=1) + ax.plot(tpr, fpr, label=str(p.name+" (AUC = {:.3f})".format(roc_auc)), color=color) + if plot_thresholds: + ax2.plot(tpr, threshold, "--", color=color) if xlim is not None: - plt.xlim(*xlim) + ax.set_xlim(*xlim) if ylim is not None: - plt.ylim(*ylim) + ax.set_ylim(*ylim) # plt.xticks(np.arange(0,1,0.1)) # plt.yticks(np.arange(0,1,0.1)) - plt.legend(loc='lower left', framealpha=1.0) - plt.title('Receiver operating characteristic') - plt.ylabel("Background rejection") - plt.xlabel("Signal efficiency") - plt.savefig(filename) - plt.clf() + ax.legend(loc='lower left', framealpha=1.0) + ax.set_title('Receiver operating characteristic') + ax.set_ylabel("Background rejection") + ax.set_xlabel("Signal efficiency") + if plot_thresholds: + # to fit right y-axis description + fig.tight_layout() + fig.savefig(filename) + plt.close(fig) def overlay_loss(filename, *projects, **kwargs):