Skip to content
Snippets Groups Projects
compare.py 2.98 KiB
Newer Older
#!/usr/bin/env python

import logging
logger = logging.getLogger(__name__)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

from toolkit import KerasROOTClassification

"""
Helper functions to compare different training setups, e.g. by
overlaying their ROC curves and loss curves in a single plot.
"""

def overlay_ROC(filename, *projects):
    """Overlay the ROC curves of several projects in one plot and save it.

    For each project the ROC curve is computed on its test sample
    (``y_test`` labels, ``scores_test`` scores, weighted by ``w_test``) and
    drawn as background rejection (1 - false positive rate) versus signal
    efficiency (true positive rate). A dashed diagonal labelled "Luck" marks
    the response of a random classifier.

    Works on the current pyplot figure and clears it after saving.

    :param filename: path the figure is saved to
    :param projects: objects exposing ``name``, ``y_test``, ``scores_test``
                     and ``w_test`` (e.g. KerasROOTClassification instances)
    """

    logger.info("Overlay ROC curves for {}".format([p.name for p in projects]))

    for p in projects:
        fpr, tpr, _ = roc_curve(p.y_test, p.scores_test, sample_weight=p.w_test)
        plt.plot(tpr, 1.0 - fpr, label=p.name)

    # The reference line must be drawn before plt.legend() is called,
    # otherwise its label is not picked up by the legend.
    plt.plot([0, 1], [1, 0], linestyle='--', color='black', label='Luck')

    # Grid is an axes setting, so one call is enough (not one per project).
    plt.grid(color='gray', linestyle='--', linewidth=1)

    # 0.0, 0.1, ..., 1.0 -- np.arange(0, 1, 0.1) would drop the 1.0 tick.
    ticks = np.linspace(0, 1, 11)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.legend(loc='lower left', framealpha=1.0)
    plt.title('Receiver operating characteristic')
    plt.ylabel("Background rejection")
    plt.xlabel("Signal efficiency")
    plt.savefig(filename)
    plt.clf()

def overlay_loss(filename, *projects):
    """Overlay training and validation loss curves of several projects.

    Both curves of every project are drawn against the epoch index on a
    logarithmic y axis: training loss dashed, validation loss solid. The
    figure is saved to ``filename`` and the current pyplot figure is then
    cleared.

    :param filename: path the figure is saved to
    :param projects: objects exposing ``name`` and ``history.history``, a
                     mapping with ``'loss'`` and ``'val_loss'`` sequences
                     (presumably a Keras History object -- confirm)
    """

    project_names = [p.name for p in projects]
    logger.info("Overlay loss curves for {}".format(project_names))

    for project in projects:
        curves = project.history.history
        plt.semilogy(curves['loss'], linestyle='--',
                     label="Training Loss " + project.name)
        plt.semilogy(curves['val_loss'],
                     label="Validation Loss " + project.name)

    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(loc='upper right')
    plt.savefig(filename)
    plt.clf()



if __name__ == "__main__":

    import os

    logging.basicConfig()
    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    # Input configuration shared by both setups being compared.
    data_options = dict(
        signal_trees=[(filename, "GG_oneStep_1705_1105_505_NoSys")],
        bkg_trees=[
            (filename, "ttbar_NoSys"),
            (filename, "wjets_Sherpa221_NoSys"),
        ],
        selection="lep1Pt<5000",  # cut out a few very weird outliers
        branches=["met", "mt"],
        weight_expr="eventWeight*genWeight",
        identifiers=["DatasetNumber", "EventNumber"],
        step_bkg=100,
    )

    # NOTE(review): lr=1000. is an unusually large SGD learning rate --
    # confirm this is intentional and not a typo.
    example1 = KerasROOTClassification(
        "test_sgd",
        optimizer="SGD",
        optimizer_opts=dict(lr=1000., decay=1e-6, momentum=0.9),
        **data_options
    )
    example2 = KerasROOTClassification(
        "test_adam",
        optimizer="Adam",
        **data_options
    )

    # Train a setup only when its scores_test.h5 output file does not exist.
    to_train = [
        (example1, "outputs/test_sgd/scores_test.h5"),
        (example2, "outputs/test_adam/scores_test.h5"),
    ]
    for project, scores_path in to_train:
        if not os.path.exists(scores_path):
            project.train(epochs=20)

    overlay_ROC("overlay_ROC.pdf", example1, example2)
    overlay_loss("overlay_loss.pdf", example1, example2)