Skip to content
Snippets Groups Projects
eval_model.py 1.52 KiB
Newer Older
Nikolai's avatar
Nikolai committed
#!/usr/bin/env python

import argparse
import logging
import os

import keras
import h5py
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np

from KerasROOTClassification import ClassificationProject

# Command-line interface: the project directory to load, the weights file to
# evaluate, and an optional prefix for the output plot file name(s).
# NOTE(review): the description mentions train/test overlayed scores, but the
# visible code only writes the ROC plot.
parser = argparse.ArgumentParser(
    description=('Evaluate a model from a classification project using the given '
                 'weights and plot the ROC curve and train/test overlayed scores'),
)
for _positional in ("project_dir", "weights"):
    parser.add_argument(_positional)
parser.add_argument("-p", "--plot-prefix", default="eval_nn")
args = parser.parse_args()

# Load the classification project from its directory and replace the model's
# weights with the ones passed on the command line.
c = ClassificationProject(args.project_dir)

c.model.load_weights(args.weights)

print("Predicting for test sample ...")
scores_test = c.evaluate(c.x_test)
print("Done")

# Weighted ROC curve on the test sample. After the transformation below,
# ``fpr`` holds the background rejection (1 - false positive rate) and ``tpr``
# the signal efficiency; both are plotted further down under these names.
fpr, tpr, threshold = roc_curve(c.y_test, scores_test, sample_weight=c.w_test)
fpr = 1.0 - fpr
try:
    # auc() requires a monotonic x axis; tpr from roc_curve normally is.
    roc_auc = auc(tpr, fpr)
except ValueError:
    # tpr can become non-monotonic, e.g. if some sample weights are negative.
    # Sort the points by (tpr, rejection) explicitly -- this is what the
    # ``reorder=True`` option of sklearn.metrics.auc did before it was removed.
    logging.warning("Got a value error from auc - trying to rerun with sorted points")
    order = np.lexsort((fpr, tpr))
    roc_auc = auc(tpr[order], fpr[order])

# ROC plot: signal efficiency on x, background rejection on y. The dashed
# anti-diagonal is the reference curve for a random guess ("Luck").
plt.grid(color='gray', linestyle='--', linewidth=1)
plt.plot(tpr, fpr, label="{} (AUC = {:.4f})".format(c.name, roc_auc))
plt.plot([0, 1], [1, 0], linestyle='--', color='black', label='Luck')
plt.ylabel("Background rejection")
plt.xlabel("Signal efficiency")
plt.title('Receiver operating characteristic')
plt.xlim(0, 1)
plt.ylim(0, 1)
# linspace includes the upper edge, so the 1.0 tick is drawn as well
# (np.arange(0, 1, 0.1) stops at 0.9).
ticks = np.linspace(0, 1, 11)
plt.xticks(ticks)
plt.yticks(ticks)
plt.legend(loc='lower left', framealpha=1.0)
plt.savefig(args.plot_prefix + "_ROC.pdf")
plt.clf()