From 1f5ff1b3afb9911b8ff45f11ddc1b625b953cc6c Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Sun, 12 Aug 2018 11:30:56 +0200
Subject: [PATCH] staring eval_model script

---
 scripts/eval_model.py | 50 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)
 create mode 100755 scripts/eval_model.py

diff --git a/scripts/eval_model.py b/scripts/eval_model.py
new file mode 100755
index 0000000..63c4bbd
--- /dev/null
+++ b/scripts/eval_model.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+
+import os
+import argparse
+
+import keras
+import h5py
+from sklearn.metrics import roc_curve, auc
+import matplotlib.pyplot as plt
+import numpy as np
+
+from KerasROOTClassification import ClassificationProject
+
+parser = argparse.ArgumentParser(description='Evaluate a model from a classification project using the given '
+                                             'weights and plot the ROC curve and train/test overlayed scores')
+parser.add_argument("project_dir")
+parser.add_argument("weights")
+parser.add_argument("-p", "--plot-prefix", default="eval_nn")
+args = parser.parse_args()
+
+c = ClassificationProject(args.project_dir)
+
+c.model.load_weights(args.weights)
+
+print("Predicting for test sample ...")
+scores_test = c.evaluate(c.x_test)
+print("Done")
+
+fpr, tpr, threshold = roc_curve(c.y_test, scores_test, sample_weight = c.w_test)
+fpr = 1.0 - fpr
+try:
+    roc_auc = auc(tpr, fpr, reorder=True)
+except ValueError:
+    logger.warning("Got a value error from auc - trying to rerun with reorder=True")
+    roc_auc = auc(tpr, fpr, reorder=True)
+
+plt.grid(color='gray', linestyle='--', linewidth=1)
+plt.plot(tpr,  fpr, label=str(c.name + " (AUC = {})".format(roc_auc)))
+plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
+plt.ylabel("Background rejection")
+plt.xlabel("Signal efficiency")
+plt.title('Receiver operating characteristic')
+plt.xlim(0,1)
+plt.ylim(0,1)
+plt.xticks(np.arange(0,1,0.1))
+plt.yticks(np.arange(0,1,0.1))
+plt.legend(loc='lower left', framealpha=1.0)
+plt.savefig(args.plot_prefix+"_ROC.pdf")
+plt.clf()
+
-- 
GitLab