Skip to content
Snippets Groups Projects
Commit cbdbdefc authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Adding a few functions to compare setups

parent 9e229cfd
No related branches found
No related tags found
No related merge requests found
......@@ -3,3 +3,4 @@ setup.sh
run.py
*.swp
*.pyc
*.pdf
#!/usr/bin/env python
import logging
logger = logging.getLogger(__name__)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from toolkit import KerasROOTClassification
"""
A few functions to compare different setups
"""
def overlay_ROC(filename, *projects):
logger.info("Overlay ROC curves for {}".format([p.name for p in projects]))
for p in projects:
fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight = p.w_test)
fpr = 1.0 - fpr
plt.grid(color='gray', linestyle='--', linewidth=1)
plt.plot(tpr, fpr, label=p.name)
plt.xlim(0,1)
plt.ylim(0,1)
plt.xticks(np.arange(0,1,0.1))
plt.yticks(np.arange(0,1,0.1))
plt.legend(loc='lower left', framealpha=1.0)
plt.title('Receiver operating characteristic')
plt.ylabel("Background rejection")
plt.xlabel("Signal efficiency")
plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
plt.savefig(filename)
plt.clf()
def overlay_loss(filename, *projects):
logger.info("Overlay loss curves for {}".format([p.name for p in projects]))
for p in projects:
plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name)
plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper right')
plt.savefig(filename)
plt.clf()
if __name__ == "__main__":
import os
logging.basicConfig()
#logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
data_options = dict(signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
bkg_trees = [(filename, "ttbar_NoSys"),
(filename, "wjets_Sherpa221_NoSys")
],
selection="lep1Pt<5000", # cut out a few very weird outliers
branches = ["met", "mt"],
weight_expr = "eventWeight*genWeight",
identifiers = ["DatasetNumber", "EventNumber"],
step_bkg = 100)
example1 = KerasROOTClassification("test_sgd",
optimizer="SGD",
optimizer_opts=dict(lr=1000., decay=1e-6, momentum=0.9),
**data_options)
example2 = KerasROOTClassification("test_adam",
optimizer="Adam",
**data_options)
if not os.path.exists("outputs/test_sgd/scores_test.h5"):
example1.train(epochs=20)
if not os.path.exists("outputs/test_adam/scores_test.h5"):
example2.train(epochs=20)
overlay_ROC("overlay_ROC.pdf", example1, example2)
overlay_loss("overlay_loss.pdf", example1, example2)
......@@ -23,7 +23,6 @@ from keras.models import model_from_json
from keras.callbacks import History
from keras.optimizers import SGD
import keras.optimizers
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment