#!/usr/bin/env python

import os
import json

import logging
logger = logging.getLogger("KerasROOTClassification")
logger.addHandler(logging.NullHandler())

from root_numpy import tree2array, rec2array
import numpy as np
import pandas as pd
import h5py
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.externals import joblib
from sklearn.metrics import roc_curve, auc

from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
import matplotlib.pyplot as plt

# configure number of cores
# this doesn't seem to work, but at least with these settings keras only uses 4 processes
import tensorflow as tf
from keras import backend as K
num_cores = 1
config = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
                        inter_op_parallelism_threads=num_cores,
                        allow_soft_placement=True,
                        device_count={'CPU': num_cores})
session = tf.Session(config=config)
K.set_session(session)

import ROOT

class KerasROOTClassification:
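    """
    Train a simple Keras MLP to separate signal from background events read
    from ROOT trees.

    All products (hdf5 copies of the datasets, the fitted scaler, the model
    weights, the epoch count and the plots) are persisted in the project
    directory <out_dir>/<name>, so a subsequent run with the same name
    resumes from the stored state.
    """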

    # names of the arrays that are cached to (and restored from) hdf5
    dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]

    def __init__(self, name,
                 signal_trees, bkg_trees, branches, weight_expr, identifiers,
                 selection=None,
                 layers=3,
                 nodes=64,
                 batch_size=128,
                 validation_split=0.33,
                 activation_function='relu',
                 out_dir="./outputs",
                 scaler_type="RobustScaler"):
        self.name = name
        self.signal_trees = signal_trees
        self.bkg_trees = bkg_trees
        self.branches = branches
        self.weight_expr = weight_expr
        self.selection = selection
        self.identifiers = identifiers
        self.layers = layers
        self.nodes = nodes
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.activation_function = activation_function
        self.out_dir = out_dir
        self.scaler_type = scaler_type

        self.project_dir = os.path.join(self.out_dir, name)

        if not os.path.exists(self.out_dir):
            os.mkdir(self.out_dir)

        if not os.path.exists(self.project_dir):
            os.mkdir(self.project_dir)

        self.s_train = None
        self.b_train = None
        self.s_test = None
        self.b_test = None
        self.x_train = None
        self.x_test = None
        self.y_train = None
        self.y_test = None

        self.s_eventlist_train = None
        self.b_eventlist_train = None

        self._scaler = None
        self._class_weight = None
        self._bkg_weights = None
        self._sig_weights = None
        self._model = None
        self._history = None
        self.scores_train = None
        self.scores_test = None

        # track the number of epochs this model has been trained for
        self.total_epochs = 0
        self.data_loaded = False
        self.data_transformed = False

    def _load_data(self):
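        """
        Load the datasets, trying the hdf5 cache in the project directory
        first and falling back to reading the ROOT trees (in which case the
        hdf5 cache is written for the next run).
        """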

        try:
            self._load_from_hdf5()
        except (KeyError, IOError):
            logger.info("Couldn't load all datasets - reading from ROOT trees")

            # Read signal and background trees into structured numpy arrays
            signal_chain = ROOT.TChain()
            bkg_chain = ROOT.TChain()
            for filename, treename in self.signal_trees:
                signal_chain.AddFile(filename, -1, treename)
            for filename, treename in self.bkg_trees:
                bkg_chain.AddFile(filename, -1, treename)
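            # split into train/test by interleaving: even entries for
            # training, odd entries for testing; for the background only
            # every 200th event is used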
            self.s_train = tree2array(signal_chain,
                                      branches=self.branches+[self.weight_expr]+self.identifiers,
                                      selection=self.selection,
                                      start=0, step=2)
            self.b_train = tree2array(bkg_chain,
                                      branches=self.branches+[self.weight_expr]+self.identifiers,
                                      selection=self.selection,
                                      start=0, step=200)
            self.s_test = tree2array(signal_chain,
                                     branches=self.branches+[self.weight_expr],
                                     selection=self.selection,
                                     start=1, step=2)
            self.b_test = tree2array(bkg_chain,
                                     branches=self.branches+[self.weight_expr],
                                     selection=self.selection,
                                     start=1, step=200)

            self._dump_training_list()
            self.s_eventlist_train = self.s_train[self.identifiers]
            self.b_eventlist_train = self.b_train[self.identifiers]

            # now we don't need the identifiers anymore
            self.s_train = self.s_train[self.branches+[self.weight_expr]]
            self.b_train = self.b_train[self.branches+[self.weight_expr]]

            # create x (input), y (target) and w (weights) arrays
            # the first block will be signals, the second block backgrounds
            self.x_train = rec2array(self.s_train[self.branches])
            self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.branches])))
            self.x_test = rec2array(self.s_test[self.branches])
            self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.branches])))
            self.w_train = self.s_train[self.weight_expr]
            self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr]))
            self.w_test = self.s_test[self.weight_expr]
            self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr]))

            self.y_train = np.empty(len(self.x_train))
            self.y_train[:len(self.s_train)] = 1
            self.y_train[len(self.s_train):] = 0
            self.y_test = np.empty(len(self.x_test))
            self.y_test[:len(self.s_test)] = 1
            self.y_test[len(self.s_test):] = 0

            logger.info("Writing to hdf5")
            self._dump_to_hdf5()

        self.data_loaded = True


    def _dump_training_list(self):
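        """Write the identifiers of the training events to csv files."""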
        s_eventlist = pd.DataFrame(self.s_train[self.identifiers])
        b_eventlist = pd.DataFrame(self.b_train[self.identifiers])

        s_eventlist.to_csv(os.path.join(self.project_dir, "s_eventlist_train.csv"))
        b_eventlist.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))


    def _dump_to_hdf5(self):
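        """Write all arrays listed in dataset_names to hdf5, one file per array."""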
        for dataset_name in self.dataset_names:
            with h5py.File(os.path.join(self.project_dir, dataset_name+".h5"), "w") as hf:
                hf.create_dataset(dataset_name, data=getattr(self, dataset_name))


    def _load_from_hdf5(self):
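        """Read back all arrays listed in dataset_names from the hdf5 files."""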
        for dataset_name in self.dataset_names:
            filename = os.path.join(self.project_dir, dataset_name+".h5")
            logger.info("Trying to load {} from {}".format(dataset_name, filename))
            with h5py.File(filename, "r") as hf:
                setattr(self, dataset_name, hf[dataset_name][:])
        logger.info("Data loaded")
    @property
    def scaler(self):
        # create the scaler (and fit it to the training data) if it doesn't exist yet
        if self._scaler is None:
            filename = os.path.join(self.project_dir, "scaler.pkl")
            try:
                self._scaler = joblib.load(filename)
                logger.info("Loaded existing scaler from {}".format(filename))
            except IOError:
                logger.info("Creating new {}".format(self.scaler_type))
                if self.scaler_type == "StandardScaler":
                    self._scaler = StandardScaler()
                elif self.scaler_type == "RobustScaler":
                    self._scaler = RobustScaler()
                else:
                    raise ValueError("Scaler type {} unknown".format(self.scaler_type))
                logger.info("Fitting {} to training data".format(self.scaler_type))
                self._scaler.fit(self.x_train)
                # I think this would refit to the test data (and overwrite the
                # parameters); we probably want to fit either only the training
                # data, or training and test data together
                # logger.info("Fitting StandardScaler to test data")
                # self._scaler.fit(self.x_test)
                joblib.dump(self._scaler, filename)
        return self._scaler


    def _transform_data(self):
        if not self.data_transformed:
            # todo: what to do about the outliers? Where do they come from?
            logger.debug("training data before transformation: {}".format(self.x_train))
            logger.debug("minimum values: {}".format([np.min(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
            logger.debug("maximum values: {}".format([np.max(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
            self.x_train = self.scaler.transform(self.x_train)
            logger.debug("training data after transformation: {}".format(self.x_train))
            self.x_test = self.scaler.transform(self.x_test)
            self.data_transformed = True


    def _read_info(self, key, default):
        filename = os.path.join(self.project_dir, "info.json")
        if not os.path.exists(filename):
            with open(filename, "w") as of:
                json.dump({}, of)
        with open(filename) as f:
            info = json.load(f)
        return info.get(key, default)

    def _write_info(self, key, value):
        """Write a single key to info.json in the project directory."""
        filename = os.path.join(self.project_dir, "info.json")
        with open(filename) as f:
            info = json.load(f)
        info[key] = value
        with open(filename, "w") as of:
            json.dump(info, of)


    @property
    def model(self):
        "Simple MLP"

        if self._model is None:

            self._model = Sequential()

            # first hidden layer
            self._model.add(Dense(self.nodes, input_dim=len(self.branches), activation=self.activation_function))
            # the other hidden layers
            for layer_number in range(self.layers-1):
                self._model.add(Dense(self.nodes, activation=self.activation_function))
            # last layer is one neuron (binary classification)
            self._model.add(Dense(1, activation='sigmoid'))
            logger.info("Compile model")
            self._model.compile(optimizer='SGD',
                                loss='binary_crossentropy',
                                metrics=['accuracy'])

            # dump to json for documentation
            with open(os.path.join(self.project_dir, "model.json"), "w") as of:
                of.write(self._model.to_json())

        return self._model

    @property
    def class_weight(self):
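        """
        Weights that balance signal against background: with
        w_tot = sumw_sig + sumw_bkg, background events are scaled by
        w_tot/(2*sumw_bkg) and signal events by w_tot/(2*sumw_sig), so both
        classes contribute half of the total weight to the loss.
        """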
        if self._class_weight is None:
            sumw_bkg = np.sum(self.w_train[self.y_train == 0])
            sumw_sig = np.sum(self.w_train[self.y_train == 1])
            # use a dict (class index -> weight) as expected by keras' fit
            self._class_weight = {0: (sumw_sig+sumw_bkg)/(2*sumw_bkg),
                                  1: (sumw_sig+sumw_bkg)/(2*sumw_sig)}
        return self._class_weight

    def train(self, epochs=10):
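        """
        Train for the given number of epochs, continuing from persisted
        weights if present, then update the stored epoch count and compute
        the scores on the training and test samples.
        """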

        if not self.data_loaded:
            self._load_data()

        if not self.data_transformed:
            self._transform_data()

        for branch_index, branch in enumerate(self.branches):
            self.plot_input(branch_index)

        try:
            self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
            logger.info("Weights found and loaded")
            logger.info("Continue training")
        except IOError:
            logger.info("No weights found, starting completely new training")

        self.total_epochs = self._read_info("epochs", 0)

        logger.info("Train model")
        self._history = self.model.fit(self.x_train,
                                       # the reshape might be unnecessary here
                                       self.y_train.reshape(-1, 1),
                                       epochs=epochs,
                                       validation_split=self.validation_split,
                                       class_weight=self.class_weight,
                                       shuffle=True,
                                       batch_size=self.batch_size)
        logger.info("Save weights")
        self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))

        self.total_epochs += epochs
        self._write_info("epochs", self.total_epochs)

        logger.info("Create scores for ROC curve")
        self.scores_test = self.model.predict(self.x_test)
        self.scores_train = self.model.predict(self.x_train)



    def evaluate(self):
        pass

    def write_friend_tree(self):
        pass


    @property
    def bkg_weights(self):
        """
        class weights multiplied by event weights (for plotting)
        TODO: find a better way to do this
        """
        if self._bkg_weights is None:
            logger.debug("Calculating background weights for plotting")
            self._bkg_weights = np.empty(sum(self.y_train == 0))
            self._bkg_weights.fill(self.class_weight[0])
            self._bkg_weights *= self.w_train[self.y_train == 0]
            logger.debug("Background weights: {}".format(self._bkg_weights))
        return self._bkg_weights


    @property
    def sig_weights(self):
        """
        class weights multiplied by event weights (for plotting)
        TODO: find a better way to do this
        """
        if self._sig_weights is None:
            logger.debug("Calculating signal weights for plotting")
            self._sig_weights = np.empty(sum(self.y_train == 1))
            self._sig_weights.fill(self.class_weight[1])
            self._sig_weights *= self.w_train[self.y_train == 1]
            logger.debug("Signal weights: {}".format(self._sig_weights))
        return self._sig_weights


    def plot_input(self, var_index):
        "plot a single input variable"
        branch = self.branches[var_index]
        fig, ax = plt.subplots()
        bkg = self.x_train[:,var_index][self.y_train == 0]
        sig = self.x_train[:,var_index][self.y_train == 1]
        logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
        logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))

        # calculate percentiles to get a heuristic for the range to be plotted
        # (should in principle also be done with weights, but for now do it unweighted)
        range_sig = np.percentile(sig, [1, 99])
        range_bkg = np.percentile(bkg, [1, 99])
        plot_range = (min(range_sig[0], range_bkg[0]), max(range_sig[1], range_bkg[1]))

        logger.debug("Calculated range based on percentiles: {}".format(plot_range))

        try:
            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
        except ValueError:
            # workaround (not guaranteed to always work) for a numpy bug
            plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
            logger.warning("Got a ValueError during plotting, possibly due to a numpy bug - changing range to {}".format(plot_range))
            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
        ax.set_xlabel(branch+" (transformed)")
        plot_dir = os.path.join(self.project_dir, "plots")
        if not os.path.exists(plot_dir):
            os.mkdir(plot_dir)
        fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
        plt.clf()

    def plot_ROC(self):
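        """Plot background rejection vs. signal efficiency for the test sample."""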

        logger.info("Plot ROC curve")
        fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight=self.w_test)
        # plot background rejection (1 - FPR) vs. signal efficiency (TPR);
        # the area under this curve is the same as the usual ROC AUC
        fpr = 1.0 - fpr
        roc_auc = auc(tpr, fpr)

        plt.grid(color='gray', linestyle='--', linewidth=1)
        plt.plot(tpr, fpr, label='NN')
        plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
        plt.ylabel("Background rejection")
        plt.xlabel("Signal efficiency")
        plt.title('Receiver operating characteristic')
        plt.xlim(0,1)
        plt.ylim(0,1)
        plt.xticks(np.arange(0,1,0.1))
        plt.yticks(np.arange(0,1,0.1))
        plt.legend(loc='lower left', framealpha=1.0)
        plt.text(0.21, 0.02, "AUC: {:.3f}".format(roc_auc), size=12)
        plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
        plt.clf()

    def plot_score(self):
        pass

    def plot_loss(self):

        logger.info("Plot losses")
        plt.plot(self._history.history['loss'])
        plt.plot(self._history.history['val_loss'])
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
        plt.clf()

    def plot_accuracy(self):
        logger.info("Plot accuracy")
        plt.plot(self._history.history['acc'])
        plt.plot(self._history.history['val_acc'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
        plt.clf()

if __name__ == "__main__":

    logging.basicConfig()
    #logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = KerasROOTClassification("test2",
                                signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
                                bkg_trees = [(filename, "ttbar_NoSys"),
                                             (filename, "wjets_Sherpa221_NoSys")
                                ],
                                selection="lep1Pt<5000", # cut out a few very weird outliers
                                branches = ["met", "mt"],
                                weight_expr = "eventWeight*genWeight",
                                identifiers = ["DatasetNumber", "EventNumber"])

    c.train(epochs=20)
    c.plot_ROC()
    c.plot_loss()
    c.plot_accuracy()