Skip to content
Snippets Groups Projects
Unverified Commit 6ce509e4 authored by Eric Schanet's avatar Eric Schanet
Browse files

Some custom naming for ROC curve

parent dc93ef02
No related branches found
No related tags found
No related merge requests found
......@@ -946,7 +946,7 @@ class ClassificationProject(object):
plt.close(fig)
def plot_ROC(self):
def plot_ROC(self, truth=True):
logger.info("Plot ROC curve")
fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test)
......@@ -957,18 +957,25 @@ class ClassificationProject(object):
logger.warning("Got a value error from auc - trying to rerun with reorder=True")
roc_auc = auc(tpr, fpr, reorder=True)
if truth:
plot_name = "ROC_truth.pdf"
legend_name = "Truth test"
else:
plot_name = "ROC_reco.pdf"
legend_name = "Reco test"
plt.grid(color='gray', linestyle='--', linewidth=1)
plt.plot(tpr, fpr, label=str(self.name + " (AUC = {})".format(roc_auc)))
plt.plot(tpr, fpr, label=str(legend_name + " (AUC = {})".format(roc_auc)))
plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
plt.ylabel("Background rejection")
plt.xlabel("Signal efficiency")
plt.title('Receiver operating characteristic')
# plt.title('Receiver operating characteristic')
plt.xlim(0,1)
plt.ylim(0,1)
plt.xticks(np.arange(0,1,0.1))
plt.yticks(np.arange(0,1,0.1))
plt.legend(loc='lower left', framealpha=1.0)
plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
plt.savefig(os.path.join(self.project_dir, plot_name))
plt.clf()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment