Skip to content
Snippets Groups Projects
Commit cbdbdefc authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Adding a few functions to compare setups

parent 9e229cfd
No related branches found
No related tags found
No related merge requests found
...@@ -3,3 +3,4 @@ setup.sh ...@@ -3,3 +3,4 @@ setup.sh
run.py run.py
*.swp *.swp
*.pyc *.pyc
*.pdf
#!/usr/bin/env python
import logging
logger = logging.getLogger(__name__)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from toolkit import KerasROOTClassification
"""
Helper functions to compare different training setups by overlaying
their ROC curves and their loss curves in a single plot each.
"""
def overlay_ROC(filename, *projects):
    """Overlay the ROC curves of several projects in one figure and save it.

    Each curve is drawn as background rejection (1 - false positive rate)
    against signal efficiency (true positive rate), both taken from
    ``roc_curve`` evaluated on the project's test sample.

    Arguments:
    filename -- output path handed to ``plt.savefig``
    projects -- objects providing ``name``, ``y_test``, ``scores_test``
                and ``w_test``
    """
    logger.info("Overlay ROC curves for {}".format([p.name for p in projects]))

    # The grid settings do not depend on the project, so apply them once.
    plt.grid(color='gray', linestyle='--', linewidth=1)

    for p in projects:
        fpr, tpr, _ = roc_curve(p.y_test, p.scores_test, sample_weight=p.w_test)
        bkg_rejection = 1.0 - fpr
        plt.plot(tpr, bkg_rejection, label=p.name)

    # Reference diagonal. It has to be drawn before the legend is built,
    # otherwise its 'Luck' label never shows up in the legend.
    plt.plot([0, 1], [1, 0], linestyle='--', color='black', label='Luck')

    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xticks(np.arange(0, 1, 0.1))
    plt.yticks(np.arange(0, 1, 0.1))
    plt.title('Receiver operating characteristic')
    plt.ylabel("Background rejection")
    plt.xlabel("Signal efficiency")
    plt.legend(loc='lower left', framealpha=1.0)

    try:
        plt.savefig(filename)
    finally:
        # Always reset the shared pyplot figure; a failed save would
        # otherwise leak these curves into the next plot.
        plt.clf()
def overlay_loss(filename, *projects):
    """Overlay training and validation loss curves of several projects.

    For every project the 'loss' and 'val_loss' entries of
    ``project.history.history`` are drawn on a logarithmic y axis
    (training loss dashed), then the figure is saved and cleared.

    Arguments:
    filename -- output path handed to ``plt.savefig``
    projects -- objects providing ``name`` and ``history.history``
    """
    project_names = [project.name for project in projects]
    logger.info("Overlay loss curves for {}".format(project_names))

    for project in projects:
        training_label = "Training Loss "+project.name
        plt.semilogy(project.history.history['loss'],
                     linestyle='--',
                     label=training_label)

        validation_label = "Validation Loss "+project.name
        plt.semilogy(project.history.history['val_loss'],
                     label=validation_label)

    # Axis decoration and legend are shared by all curves.
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(loc='upper right')

    plt.savefig(filename)
    plt.clf()
if __name__ == "__main__":

    import os

    # Example usage: set up the same classification with two different
    # optimizers, train each one if needed, and overlay the resulting curves.
    logging.basicConfig()
    # Use logging.INFO instead of logging.DEBUG for less verbose output.
    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    # Options shared by both example projects.
    data_options = {
        "signal_trees": [(filename, "GG_oneStep_1705_1105_505_NoSys")],
        "bkg_trees": [
            (filename, "ttbar_NoSys"),
            (filename, "wjets_Sherpa221_NoSys"),
        ],
        # cut out a few very weird outliers
        "selection": "lep1Pt<5000",
        "branches": ["met", "mt"],
        "weight_expr": "eventWeight*genWeight",
        "identifiers": ["DatasetNumber", "EventNumber"],
        "step_bkg": 100,
    }

    example1 = KerasROOTClassification(
        "test_sgd",
        optimizer="SGD",
        optimizer_opts={"lr": 1000., "decay": 1e-6, "momentum": 0.9},
        **data_options
    )

    example2 = KerasROOTClassification(
        "test_adam",
        optimizer="Adam",
        **data_options
    )

    # Train only the projects whose test scores file does not exist yet.
    for example, scores_file in [
        (example1, "outputs/test_sgd/scores_test.h5"),
        (example2, "outputs/test_adam/scores_test.h5"),
    ]:
        if os.path.exists(scores_file):
            continue
        example.train(epochs=20)

    overlay_ROC("overlay_ROC.pdf", example1, example2)
    overlay_loss("overlay_loss.pdf", example1, example2)
...@@ -23,7 +23,6 @@ from keras.models import model_from_json ...@@ -23,7 +23,6 @@ from keras.models import model_from_json
from keras.callbacks import History from keras.callbacks import History
from keras.optimizers import SGD from keras.optimizers import SGD
import keras.optimizers import keras.optimizers
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment