From cbdbdefc1c1f63901a38f91f3b473b9e3f042c46 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 30 Apr 2018 11:52:41 +0200
Subject: [PATCH] Adding a few functions to compare setups

---
 .gitignore |  1 +
 compare.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 toolkit.py |  1 -
 3 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100755 compare.py

diff --git a/.gitignore b/.gitignore
index 1c28bd7..38538a2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ setup.sh
 run.py
 *.swp
 *.pyc
+*.pdf
diff --git a/compare.py b/compare.py
new file mode 100755
index 0000000..5a34371
--- /dev/null
+++ b/compare.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+"""
+A few functions to compare different setups
+"""
+
+import logging
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import roc_curve, auc
+
+from toolkit import KerasROOTClassification
+
+logger = logging.getLogger(__name__)
+
+def overlay_ROC(filename, *projects):
+    """Overlay the ROC curves of all given projects and save them to `filename`.
+
+    Each project must provide `name`, `y_test`, `scores_test` and `w_test`.
+    """
+    logger.info("Overlay ROC curves for {}".format([p.name for p in projects]))
+    for p in projects:
+        fpr, tpr, _ = roc_curve(p.y_test, p.scores_test, sample_weight=p.w_test)
+        plt.plot(tpr, 1.0 - fpr, label=p.name)  # rejection (1 - FPR) vs. efficiency
+    # Draw the diagonal before plt.legend() so that 'Luck' gets a legend entry
+    plt.plot([0, 1], [1, 0], linestyle='--', color='black', label='Luck')
+    plt.grid(color='gray', linestyle='--', linewidth=1)  # once, not per project
+    plt.xlim(0, 1)
+    plt.ylim(0, 1)
+    plt.xticks(np.linspace(0, 1, 11))  # 0.0, 0.1, ..., 1.0 (arange dropped 1.0)
+    plt.yticks(np.linspace(0, 1, 11))
+    plt.legend(loc='lower left', framealpha=1.0)
+    plt.title('Receiver operating characteristic')
+    plt.ylabel("Background rejection")
+    plt.xlabel("Signal efficiency")
+    plt.savefig(filename)
+    plt.clf()
+
+def overlay_loss(filename, *projects):
+    """Overlay training/validation loss curves of `projects`, save to `filename`."""
+    logger.info("Overlay loss curves for {}".format([p.name for p in projects]))
+    for project in projects:
+        losses = project.history.history  # dict with 'loss' and 'val_loss' entries
+        plt.semilogy(losses['loss'], linestyle='--', label="Training Loss "+project.name)
+        plt.semilogy(losses['val_loss'], label="Validation Loss "+project.name)
+    plt.ylabel('loss')
+    plt.xlabel('epoch')
+    plt.legend(loc='upper right')
+    plt.savefig(filename)
+    plt.clf()
+
+
+
+if __name__ == "__main__":
+    # Example: set up two projects (SGD and Adam) on the same data and overlay them
+    import os
+
+    logging.basicConfig()
+    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
+
+    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
+
+    # Dataset definition shared by both projects
+    data_options = {
+        "signal_trees": [(filename, "GG_oneStep_1705_1105_505_NoSys")],
+        "bkg_trees": [(filename, "ttbar_NoSys"),
+                      (filename, "wjets_Sherpa221_NoSys")],
+        "selection": "lep1Pt<5000",  # cut out a few very weird outliers
+        "branches": ["met", "mt"],
+        "weight_expr": "eventWeight*genWeight",
+        "identifiers": ["DatasetNumber", "EventNumber"],
+        "step_bkg": 100,
+    }
+
+    sgd_options = {"lr": 1000., "decay": 1e-6, "momentum": 0.9}
+    example1 = KerasROOTClassification(
+        "test_sgd", optimizer="SGD", optimizer_opts=sgd_options, **data_options
+    )
+
+    example2 = KerasROOTClassification("test_adam", optimizer="Adam", **data_options)
+
+    # Train a project only if its scores file does not exist yet
+    for example, scores_file in [
+        (example1, "outputs/test_sgd/scores_test.h5"),
+        (example2, "outputs/test_adam/scores_test.h5"),
+    ]:
+        if not os.path.exists(scores_file):
+            example.train(epochs=20)
+
+    overlay_ROC("overlay_ROC.pdf", example1, example2)
+    overlay_loss("overlay_loss.pdf", example1, example2)
diff --git a/toolkit.py b/toolkit.py
index bd06cdc..72a4228 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -23,7 +23,6 @@ from keras.models import model_from_json
 from keras.callbacks import History
 from keras.optimizers import SGD
 import keras.optimizers
-import matplotlib.pyplot as plt
 
 import matplotlib.pyplot as plt
 
-- 
GitLab