#!/usr/bin/env python

import os

# Module logger, registered under the fixed name "KerasROOTClassification".
# The NullHandler keeps the module silent unless the application configures
# logging itself (as the __main__ section at the bottom of this file does).
import logging
logger = logging.getLogger("KerasROOTClassification")
logger.addHandler(logging.NullHandler())

# Third-party packages used by the classifier below
from root_numpy import tree2array, rec2array
import numpy as np
import pandas as pd
import h5py

import ROOT

class KerasROOTClassification:
    """Prepare signal/background datasets from ROOT trees for a classifier.

    The class reads the requested branches from signal and background trees,
    splits the entries into a training half (even entries) and a test half
    (odd entries), and builds the arrays named in ``dataset_names``:

    * ``x_*``: input variables (signal rows first, then background rows)
    * ``y_*``: targets (1 for signal, 0 for background)
    * ``w_*``: event weights taken from ``weight_expr``

    The arrays are cached as one HDF5 file per dataset inside the project
    directory ``<out_dir>/<name>`` and reloaded from there on later runs.

    Parameters
    ----------
    name : project name; also the name of the project sub-directory
    signal_trees : list of ``(filename, treename)`` tuples for signal
    bkg_trees : list of ``(filename, treename)`` tuples for background
    branches : list of branch names used as input variables
    weight_expr : branch name / expression used as event weight
    identifiers : list of branches that identify an event; they are only
        read for the training half and dumped to csv
    layers, nodes : stored on the instance; not used by any method yet
    out_dir : base directory for all projects
    """

    # Names of the arrays that are cached on disk (one HDF5 file per name)
    dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]

    def __init__(self, name,
                 signal_trees, bkg_trees, branches, weight_expr, identifiers,
                 layers=3, nodes=64, out_dir="./outputs"):
        """Store the configuration and create the project directory."""
        self.name = name
        self.signal_trees = signal_trees
        self.bkg_trees = bkg_trees
        self.branches = branches
        self.weight_expr = weight_expr
        self.identifiers = identifiers
        self.layers = layers
        self.nodes = nodes
        self.out_dir = out_dir

        self.project_dir = os.path.join(self.out_dir, name)

        # makedirs also creates out_dir (and any missing parents), so a
        # nested out_dir works as well
        if not os.path.isdir(self.project_dir):
            os.makedirs(self.project_dir)

        self.s_train = None
        self.b_train = None
        self.s_test = None
        self.b_test = None
        self.x_train = None
        self.x_test = None
        self.y_train = None
        self.y_test = None
        # the weight arrays are part of dataset_names, so initialise them too
        self.w_train = None
        self.w_test = None

        self.s_eventlist_train = None
        self.b_eventlist_train = None

    @staticmethod
    def _build_chain(trees):
        """Return a TChain containing all ``(filename, treename)`` pairs."""
        chain = ROOT.TChain()
        for filename, treename in trees:
            chain.AddFile(filename, -1, treename)
        return chain

    def load_data(self):
        """Fill the x/y/w train and test arrays.

        The cached HDF5 files are tried first. If a file or a dataset is
        missing, everything is read again from the ROOT trees, the list of
        training events is dumped to csv and the HDF5 cache is rewritten.
        """
        try:

            self._load_from_hdf5()

        except (KeyError, IOError, OSError):
            # KeyError: file exists but the dataset is not in it
            # IOError/OSError: file is missing or cannot be opened

            logger.info("Couldn't load all datasets - reading from ROOT trees")

            # Read signal and background trees into structured numpy arrays
            signal_chain = self._build_chain(self.signal_trees)
            bkg_chain = self._build_chain(self.bkg_trees)

            test_branches = self.branches + [self.weight_expr]
            train_branches = test_branches + self.identifiers

            # even entries are used for training, odd entries for testing
            self.s_train = tree2array(signal_chain, branches=train_branches, start=0, step=2)
            self.b_train = tree2array(bkg_chain, branches=train_branches, start=0, step=2)
            self.s_test = tree2array(signal_chain, branches=test_branches, start=1, step=2)
            self.b_test = tree2array(bkg_chain, branches=test_branches, start=1, step=2)

            self._dump_training_list()
            self.s_eventlist_train = self.s_train[self.identifiers]
            self.b_eventlist_train = self.b_train[self.identifiers]

            # now we don't need the identifiers anymore
            self.s_train = self.s_train[test_branches]
            self.b_train = self.b_train[test_branches]

            # create x (input), y (target) and w (weights) arrays
            # the first block will be signals, the second block backgrounds
            self.x_train = np.concatenate((rec2array(self.s_train[self.branches]),
                                           rec2array(self.b_train[self.branches])))
            self.x_test = np.concatenate((rec2array(self.s_test[self.branches]),
                                          rec2array(self.b_test[self.branches])))
            self.w_train = np.concatenate((self.s_train[self.weight_expr],
                                           self.b_train[self.weight_expr]))
            self.w_test = np.concatenate((self.s_test[self.weight_expr],
                                          self.b_test[self.weight_expr]))

            self.y_train = np.concatenate((np.ones(len(self.s_train)),
                                           np.zeros(len(self.b_train))))
            self.y_test = np.concatenate((np.ones(len(self.s_test)),
                                          np.zeros(len(self.b_test))))

            logger.info("Writing to hdf5")
            self._dump_to_hdf5()

    def _dump_training_list(self):
        """Write the identifiers of all training events to csv.

        Creates ``s_eventlist_train.csv`` (signal) and
        ``b_eventlist_train.csv`` (background) in the project directory.
        Requires ``s_train`` and ``b_train`` to still contain the
        identifier fields.
        """
        s_eventlist = pd.DataFrame(
            {field: self.s_train[field] for field in self.identifiers},
            columns=self.identifiers)
        b_eventlist = pd.DataFrame(
            {field: self.b_train[field] for field in self.identifiers},
            columns=self.identifiers)

        s_eventlist.to_csv(os.path.join(self.project_dir, "s_eventlist_train.csv"))
        # the background list has to go into the background file
        b_eventlist.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))

    def _dump_to_hdf5(self):
        """Write every array in ``dataset_names`` to ``<name>.h5``."""
        for dataset_name in self.dataset_names:
            with h5py.File(os.path.join(self.project_dir, dataset_name+".h5"), "w") as hf:
                hf.create_dataset(dataset_name, data=getattr(self, dataset_name))

    def _load_from_hdf5(self):
        """Load every array in ``dataset_names`` from ``<name>.h5``.

        Raises ``KeyError`` if a dataset is missing in its file and
        ``IOError``/``OSError`` if a file cannot be opened.
        """
        for dataset_name in self.dataset_names:
            filename = os.path.join(self.project_dir, dataset_name+".h5")
            logger.info("Trying to load {} from {}".format(dataset_name, filename))
            # open read-only, so loading never creates an empty file
            with h5py.File(filename, "r") as hf:
                setattr(self, dataset_name,  hf[dataset_name][:])
        logger.info("Data loaded")

    def train(self):
        """Placeholder - not implemented yet."""
        pass

    def evaluate(self):
        """Placeholder - not implemented yet."""
        pass

    def writeFriendTree(self):
        """Placeholder - not implemented yet."""
        pass

    def plotROC(self):
        """Placeholder - not implemented yet."""
        pass

    def plotScore(self):
        """Placeholder - not implemented yet."""
        pass



if __name__ == "__main__":

    # Small manual test: build the datasets for one signal point against
    # two background samples and print the training inputs

    logging.basicConfig()
    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = KerasROOTClassification("test",
                                signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
                                bkg_trees = [(filename, "ttbar_NoSys"),
                                             (filename, "wjets_Sherpa221_NoSys")
                                ],
                                branches = ["met", "mt"],
                                weight_expr = "eventWeight*genWeight",
                                identifiers = ["DatasetNumber", "EventNumber"])

    c.load_data()
    print(c.x_train)
    print(len(c.x_train))