#!/usr/bin/env python
"""
A few functions to compare different setups
"""

import logging

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

from .toolkit import ClassificationProject

logger = logging.getLogger(__name__)


def _default_colors():
    """Return the list of colors of the active matplotlib property cycle."""
    return plt.rcParams['axes.prop_cycle'].by_key()['color']


def _area_under_curve(x, y):
    """Return the area under the curve (x, y) computed by sklearn's ``auc``.

    ``auc`` raises a ValueError when ``x`` is neither increasing nor
    decreasing. In that case the points are sorted by ``x`` (ties broken by
    ``y``) and the area is computed again. The sorting is done here instead
    of via ``auc(..., reorder=True)`` because that keyword is not accepted
    by current scikit-learn versions.
    """
    try:
        return auc(x, y)
    except ValueError:
        logger.warning("Got a value error from auc - retrying with reordered points")
        x = np.asarray(x)
        y = np.asarray(y)
        # lexsort uses the LAST key as primary key: sort by x, then by y
        order = np.lexsort((y, x))
        return auc(x[order], y[order])


def overlay_ROC(filename, *projects, **kwargs):
    """Plot the ROC curves of several projects into one figure.

    The curves show background rejection (1 - false positive rate) versus
    signal efficiency (true positive rate). The AUC of each curve is given
    in the legend.

    :param filename: path the figure is saved to
    :param projects: objects providing ``name``, ``y_test``, ``scores_test``
        and ``w_test``
    :param xlim: x-axis limits, ``None`` to leave them automatic
        (default ``(0, 1)``)
    :param ylim: y-axis limits, ``None`` to leave them automatic
        (default ``(0, 1)``)
    :param plot_thresholds: also draw the score thresholds on a second
        y-axis (default ``False``)
    :param threshold_log: use a log scale for the threshold axis
        (default ``True``)
    :raises KeyError: if an unknown keyword argument is passed
    """
    xlim = kwargs.pop("xlim", (0, 1))
    ylim = kwargs.pop("ylim", (0, 1))
    plot_thresholds = kwargs.pop("plot_thresholds", False)
    threshold_log = kwargs.pop("threshold_log", True)
    if kwargs:
        raise KeyError("Unknown kwargs: {}".format(kwargs))

    logger.info("Overlay ROC curves for %s", [p.name for p in projects])

    fig, ax = plt.subplots()
    ax.grid(color='gray', linestyle='--', linewidth=1)

    if plot_thresholds:
        ax2 = ax.twinx()
        ax2.set_ylabel("Thresholds")
        if threshold_log:
            ax2.set_yscale("log")

    colors = _default_colors()
    for i, p in enumerate(projects):
        # wrap around so that no project is dropped when there are more
        # projects than colors in the cycle
        color = colors[i % len(colors)]

        fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight=p.w_test)
        bkg_rejection = 1.0 - fpr
        roc_auc = _area_under_curve(tpr, bkg_rejection)

        ax.plot(tpr, bkg_rejection,
                label=str(p.name + " (AUC = {:.3f})".format(roc_auc)),
                color=color)
        if plot_thresholds:
            ax2.plot(tpr, threshold, "--", color=color)

    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.legend(loc='lower left', framealpha=1.0)
    ax.set_title('Receiver operating characteristic')
    ax.set_ylabel("Background rejection")
    ax.set_xlabel("Signal efficiency")
    if plot_thresholds:
        # to fit right y-axis description
        fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)


def overlay_loss(filename, *projects, **kwargs):
    """Plot training and validation loss of several projects into one figure.

    For each project the training loss is drawn dashed and the validation
    loss solid, both in the same color.

    :param filename: path the figure is saved to
    :param projects: objects providing ``name`` and ``csv_hist``, where
        ``csv_hist`` is indexable by ``'loss'`` and ``'val_loss'``
    :param xlim: x-axis limits, ``None`` to leave them automatic (default)
    :param ylim: y-axis limits, ``None`` to leave them automatic (default)
    :param log: use a log scale for the y-axis (default ``False``)
    :raises KeyError: if an unknown keyword argument is passed
    """
    xlim = kwargs.pop("xlim", None)
    ylim = kwargs.pop("ylim", None)
    log = kwargs.pop("log", False)
    if kwargs:
        raise KeyError("Unknown kwargs: {}".format(kwargs))

    logger.info("Overlay loss curves for %s", [p.name for p in projects])

    # Use a dedicated figure (as overlay_ROC does) instead of pyplot's
    # implicit current figure, so nothing else gets drawn over and the
    # figure is released again afterwards.
    fig, ax = plt.subplots()

    colors = _default_colors()
    for i, p in enumerate(projects):
        # wrap around so that no project is dropped when there are more
        # projects than colors in the cycle
        color = colors[i % len(colors)]
        hist_dict = p.csv_hist
        ax.plot(hist_dict['loss'], linestyle='--', label="Training Loss " + p.name, color=color)
        ax.plot(hist_dict['val_loss'], label="Validation Loss " + p.name, color=color)

    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    if log:
        ax.set_yscale("log")
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend(loc='upper right')
    fig.savefig(filename)
    plt.close(fig)


if __name__ == "__main__":

    import os

    logging.basicConfig()
    # use logging.INFO here for less verbose output
    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    data_options = dict(signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
                        bkg_trees = [(filename, "ttbar_NoSys"),
                                     (filename, "wjets_Sherpa221_NoSys")
                        ],
                        selection="lep1Pt<5000", # cut out a few very weird outliers
                        branches = ["met", "mt"],
                        weight_expr = "eventWeight*genWeight",
                        identifiers = ["DatasetNumber", "EventNumber"],
                        step_bkg = 100)

    example1 = ClassificationProject("test_sgd",
                                     optimizer="SGD",
                                     optimizer_opts=dict(lr=1000., decay=1e-6, momentum=0.9),
                                     **data_options)

    example2 = ClassificationProject("test_adam",
                                     optimizer="Adam",
                                     **data_options)

    # only train if there are no stored test scores from a previous run
    if not os.path.exists("outputs/test_sgd/scores_test.h5"):
        example1.train(epochs=20)

    if not os.path.exists("outputs/test_adam/scores_test.h5"):
        example2.train(epochs=20)

    overlay_ROC("overlay_ROC.pdf", example1, example2)
    overlay_loss("overlay_loss.pdf", example1, example2)