Skip to content
Snippets Groups Projects
compare.py 5.41 KiB
Newer Older
#!/usr/bin/env python

import logging
logger = logging.getLogger(__name__)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

from .toolkit import ClassificationProject
from .plotting import save_show

"""
Functions to compare different setups (e.g. ClassificationProjects) by overlaying their ROC and loss curves.
"""

def overlay_ROC(filename, *projects, **kwargs):
    """Overlay the ROC curves of several projects in one figure.

    Each curve shows background rejection (1 - false positive rate) versus
    signal efficiency (true positive rate). It is computed with
    ``sklearn.metrics.roc_curve`` from the project's ``y_test``,
    ``scores_test`` and ``w_test``. The legend lists the AUC of every curve.

    :param filename: output file name, passed on to ``save_show``
    :param projects: projects to compare; each needs ``name``, ``y_test``,
        ``scores_test`` and ``w_test``
    :param xlim: signal efficiency axis limits, default ``(0, 1)``;
        ``None`` keeps matplotlib's automatic limits
    :param ylim: background rejection axis limits, default ``(0, 1)``;
        ``None`` keeps matplotlib's automatic limits
    :param plot_thresholds: also draw the score thresholds (dashed) on a
        secondary y-axis, default False
    :param threshold_log: log scale for the threshold axis, default True
    :param lumifactor: if given, add secondary axes with the absolute number
        of signal/background events (sum of test weights times
        ``lumifactor``); their scale is taken from the last project
    :param tight_layout: call ``fig.tight_layout()``, default False (it is
        always called when ``plot_thresholds`` is set)
    :returns: whatever ``save_show`` returns
    :raises KeyError: on unknown keyword arguments
    """

    xlim = kwargs.pop("xlim", (0, 1))
    ylim = kwargs.pop("ylim", (0, 1))
    plot_thresholds = kwargs.pop("plot_thresholds", False)
    threshold_log = kwargs.pop("threshold_log", True)
    lumifactor = kwargs.pop("lumifactor", None)
    tight_layout = kwargs.pop("tight_layout", False)
    if kwargs:
        raise KeyError("Unknown kwargs: {}".format(kwargs))

    logger.info("Overlay ROC curves for {}".format([p.name for p in projects]))

    fig, ax = plt.subplots()
    ax.grid(color='gray', linestyle='--', linewidth=1)

    if plot_thresholds:
        ax2 = ax.twinx()
        ax2.set_ylabel("Thresholds")
        if threshold_log:
            ax2.set_yscale("log")

    if lumifactor is not None:
        ax_abs_b = ax.twinx()
        ax_abs_s = ax.twiny()
    # Weighted event sums of the last project. They fix the scale of the
    # absolute-count axes and stay None when there is nothing to scale.
    sumw_b = sumw_s = None

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for i, p in enumerate(projects):
        # Wrap around the color cycle instead of silently dropping the
        # projects beyond its length (which zip() would do).
        color = colors[i % len(colors)]
        fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight=p.w_test)
        bkg_rej = 1.0 - fpr
        try:
            roc_auc = auc(tpr, bkg_rej)
        except ValueError:
            # auc() raises ValueError when x is neither increasing nor
            # decreasing (possibly caused by negative event weights).
            # Its old ``reorder=True`` option was removed in
            # scikit-learn 0.22. Do the same sorting explicitly: by x, with
            # ties broken by y (lexsort uses the last key as primary).
            logger.warning("Got a value error from auc - reordering the points by signal efficiency")
            order = np.lexsort((bkg_rej, tpr))
            roc_auc = auc(tpr[order], bkg_rej[order])
        ax.plot(tpr, bkg_rej, label="{} (AUC = {:.3f})".format(p.name, roc_auc), color=color)
        if plot_thresholds:
            ax2.plot(tpr, threshold, "--", color=color)
        if lumifactor is not None:
            sumw_b = p.w_test[p.y_test == 0].sum() * lumifactor
            sumw_s = p.w_test[p.y_test == 1].sum() * lumifactor
            # Invisible curves, drawn only so that the absolute axes autoscale.
            ax_abs_b.plot(tpr, fpr * sumw_b, alpha=0)
            ax_abs_s.plot(tpr * sumw_s, bkg_rej, alpha=0)

    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    if sumw_b is not None:
        # Align the absolute-count axes with the main ones. The background
        # axis gets bottom > top (rejection = 1 - fpr), i.e. it is inverted.
        y_lo, y_hi = ax.get_ylim()
        x_lo, x_hi = ax.get_xlim()
        ax_abs_b.set_ylim((1 - y_lo) * sumw_b, (1 - y_hi) * sumw_b)
        ax_abs_b.set_xlim(x_lo, x_hi)
        ax_abs_s.set_xlim(x_lo * sumw_s, x_hi * sumw_s)
        ax_abs_s.set_ylim(y_lo, y_hi)
        ax_abs_b.set_ylabel("Number of background events")
        ax_abs_s.set_xlabel("Number of signal events")
    ax.legend(loc='lower left', framealpha=1.0)
    if lumifactor is None:
        # Presumably skipped so the title does not clash with the top
        # (signal events) axis.
        ax.set_title('Receiver operating characteristic')
    ax.set_ylabel("Background rejection")
    ax.set_xlabel("Signal efficiency")
    if plot_thresholds or tight_layout:
        # to fit the right y-axis description
        fig.tight_layout()

    return save_show(plt, fig, filename)


def _loss_history(project):
    """Return a mapping with the per-epoch 'loss' and 'val_loss' of *project*.

    NOTE(review): the ClassificationProject API is not visible in this file.
    This assumes the project exposes either ``csv_hist`` or a Keras-style
    ``history`` object whose ``.history`` dict holds 'loss'/'val_loss'.
    Confirm against toolkit.py.
    """
    hist_dict = getattr(project, "csv_hist", None)
    if hist_dict is None:
        hist_dict = project.history.history
    return hist_dict


def overlay_loss(filename, *projects, **kwargs):
    """Overlay the training and validation loss curves of several projects.

    Each project gets its own color. The training loss is drawn dashed and
    the validation loss solid.

    :param filename: output file name, passed on to ``save_show``
    :param projects: projects to compare
    :param xlim: epoch axis limits, default ``None`` (automatic)
    :param ylim: loss axis limits, default ``None`` (automatic)
    :param log: use a log scale for the loss axis, default False
    :returns: whatever ``save_show`` returns
    :raises KeyError: on unknown keyword arguments
    """

    xlim = kwargs.pop("xlim", None)
    ylim = kwargs.pop("ylim", None)
    log = kwargs.pop("log", False)
    if kwargs:
        raise KeyError("Unknown kwargs: {}".format(kwargs))

    logger.info("Overlay loss curves for {}".format([p.name for p in projects]))

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    fig, ax = plt.subplots()

    for i, p in enumerate(projects):
        # Wrap around the color cycle instead of dropping extra projects.
        color = colors[i % len(colors)]
        hist_dict = _loss_history(p)
        ax.plot(hist_dict['loss'], linestyle='--', label="Training Loss " + p.name, color=color)
        ax.plot(hist_dict['val_loss'], label="Validation Loss " + p.name, color=color)

    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    if log:
        ax.set_yscale("log")
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend(loc='upper right')

    return save_show(plt, fig, filename)



if __name__ == "__main__":

    # Example: set up an SGD and an Adam classifier on the same trees, train
    # each one unless its test scores already exist on disk, then overlay
    # their ROC and loss curves.
    import os

    logging.basicConfig()
    # switch to logging.INFO here for less verbose output
    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    tree_file = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    data_options = {
        "signal_trees": [(tree_file, "GG_oneStep_1705_1105_505_NoSys")],
        "bkg_trees": [
            (tree_file, "ttbar_NoSys"),
            (tree_file, "wjets_Sherpa221_NoSys"),
        ],
        "selection": "lep1Pt<5000",  # cut out a few very weird outliers
        "branches": ["met", "mt"],
        "weight_expr": "eventWeight*genWeight",
        "identifiers": ["DatasetNumber", "EventNumber"],
        "step_bkg": 100,
    }

    sgd_project = ClassificationProject(
        "test_sgd",
        optimizer="SGD",
        optimizer_opts={"lr": 1000., "decay": 1e-6, "momentum": 0.9},
        **data_options
    )
    adam_project = ClassificationProject(
        "test_adam",
        optimizer="Adam",
        **data_options
    )

    # Train only the setups without stored test scores from an earlier run.
    pending = (
        (sgd_project, "outputs/test_sgd/scores_test.h5"),
        (adam_project, "outputs/test_adam/scores_test.h5"),
    )
    for project, scores_file in pending:
        if not os.path.exists(scores_file):
            project.train(epochs=20)

    overlay_ROC("overlay_ROC.pdf", sgd_project, adam_project)
    overlay_loss("overlay_loss.pdf", sgd_project, adam_project)