Skip to content
Snippets Groups Projects
Commit 2d10bf67 authored by Nikolai.Hartmann
Browse files

load data from hdf5 if already existent

parent 9ca36850
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
import os import os
import logging
logger = logging.getLogger("KerasROOTClassification")
logger.addHandler(logging.NullHandler())
from root_numpy import tree2array, rec2array from root_numpy import tree2array, rec2array
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -11,6 +15,8 @@ import ROOT ...@@ -11,6 +15,8 @@ import ROOT
class KerasROOTClassification: class KerasROOTClassification:
dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
def __init__(self, name, def __init__(self, name,
signal_trees, bkg_trees, branches, weight_expr, identifiers, signal_trees, bkg_trees, branches, weight_expr, identifiers,
layers=3, nodes=64, out_dir="./outputs"): layers=3, nodes=64, out_dir="./outputs"):
...@@ -46,38 +52,54 @@ class KerasROOTClassification: ...@@ -46,38 +52,54 @@ class KerasROOTClassification:
def load_data(self):
    """Populate the train/test arrays, preferring the HDF5 cache.

    First tries ``self._load_from_hdf5()``. If a cached dataset or its
    file is missing or unreadable, the signal and background ROOT trees
    are read instead: every second entry starting at 0 becomes the
    training sample, every second entry starting at 1 the test sample.
    From these the ``x_*`` (inputs), ``y_*`` (targets, 1 for signal and
    0 for background) and ``w_*`` (weights) arrays are built and written
    back with ``self._dump_to_hdf5()``.

    Note that ``s_train``/``b_train``/``s_test``/``b_test`` and the
    ``*_eventlist_train`` attributes are only set on the ROOT path.
    """
    try:
        self._load_from_hdf5()
    except (KeyError, IOError, OSError):
        # KeyError: the dataset is not present in the cache file.
        # IOError/OSError: the cache file itself is missing or unreadable
        # (h5py raises this when a file is opened read-only, which is the
        # default mode from h5py 3.0 on). Both mean "no usable cache".
        logger.info("Couldn't load all datasets - reading from ROOT trees")
    else:
        return

    # Branch selections used repeatedly below.
    value_branches = self.branches + [self.weight_expr]
    train_branches = value_branches + self.identifiers

    # Read signal and background trees into structured numpy arrays
    signal_chain = ROOT.TChain()
    bkg_chain = ROOT.TChain()
    for filename, treename in self.signal_trees:
        signal_chain.AddFile(filename, -1, treename)
    for filename, treename in self.bkg_trees:
        bkg_chain.AddFile(filename, -1, treename)
    self.s_train = tree2array(signal_chain, branches=train_branches, start=0, step=2)
    self.b_train = tree2array(bkg_chain, branches=train_branches, start=0, step=2)
    self.s_test = tree2array(signal_chain, branches=value_branches, start=1, step=2)
    self.b_test = tree2array(bkg_chain, branches=value_branches, start=1, step=2)

    self._dump_training_list()
    self.s_eventlist_train = self.s_train[self.identifiers]
    self.b_eventlist_train = self.b_train[self.identifiers]

    # now we don't need the identifiers anymore
    self.s_train = self.s_train[value_branches]
    self.b_train = self.b_train[value_branches]

    # create x (input), y (target) and w (weights) arrays
    # the first block will be signals, the second block backgrounds
    self.x_train = np.concatenate((rec2array(self.s_train[self.branches]),
                                   rec2array(self.b_train[self.branches])))
    self.x_test = np.concatenate((rec2array(self.s_test[self.branches]),
                                  rec2array(self.b_test[self.branches])))
    self.w_train = np.concatenate((self.s_train[self.weight_expr],
                                   self.b_train[self.weight_expr]))
    self.w_test = np.concatenate((self.s_test[self.weight_expr],
                                  self.b_test[self.weight_expr]))

    # targets follow the same signal-then-background ordering as x and w
    self.y_train = np.concatenate((np.ones(len(self.s_train)),
                                   np.zeros(len(self.b_train))))
    self.y_test = np.concatenate((np.ones(len(self.s_test)),
                                  np.zeros(len(self.b_test))))

    logger.info("Writing to hdf5")
    self._dump_to_hdf5()
def _dump_training_list(self): def _dump_training_list(self):
...@@ -89,18 +111,18 @@ class KerasROOTClassification: ...@@ -89,18 +111,18 @@ class KerasROOTClassification:
def _dump_to_hdf5(self):
    """Write each array listed in ``self.dataset_names`` to its own file.

    For a dataset called ``name`` the attribute ``self.<name>`` is stored
    as dataset ``name`` in ``<project_dir>/<name>.h5``. Files are opened
    in ``"w"`` mode, so existing files are overwritten.
    """
    for key in self.dataset_names:
        target_path = os.path.join(self.project_dir, key + ".h5")
        with h5py.File(target_path, "w") as out_file:
            out_file.create_dataset(key, data=getattr(self, key))
def _load_from_hdf5(self):
    """Load every array listed in ``self.dataset_names`` from its HDF5 file.

    Each dataset ``name`` is read from ``<project_dir>/<name>.h5`` and
    stored as the attribute ``self.<name>``.

    Raises:
        KeyError: if a dataset is missing from its file, or if the file
            itself cannot be opened for reading. Attributes loaded before
            the failure remain set.
    """
    for dataset_name in self.dataset_names:
        filename = os.path.join(self.project_dir, dataset_name+".h5")
        logger.info("Trying to load {} from {}".format(dataset_name, filename))
        try:
            # Open explicitly read-only: a pure read must never create or
            # modify the cache file, whatever the h5py default mode is.
            with h5py.File(filename, "r") as hf:
                setattr(self, dataset_name, hf[dataset_name][:])
        except (IOError, OSError) as err:
            # Report an unreadable/missing file the same way as a missing
            # dataset so callers have a single "not cached" signal.
            raise KeyError("Couldn't read dataset {} from {}: {}".format(
                dataset_name, filename, err))
    logger.info("Data loaded")
def train(self): def train(self):
...@@ -122,6 +144,9 @@ class KerasROOTClassification: ...@@ -122,6 +144,9 @@ class KerasROOTClassification:
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root" filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
c = KerasROOTClassification("test", c = KerasROOTClassification("test",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment