Skip to content
Snippets Groups Projects
Commit 9e229cfd authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Making loading and storage of data consistent

parent a3ed26cf
No related branches found
No related tags found
No related merge requests found
......@@ -44,8 +44,11 @@ import ROOT
class KerasROOTClassification(object):
dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
# Datasets that are stored to (and dynamically loaded from) hdf5
dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test"]
# Datasets that are retrieved from ROOT trees the first time
dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
def __init__(self, name, *args, **kwargs):
self._init_from_args(name, *args, **kwargs)
......@@ -107,6 +110,8 @@ class KerasROOTClassification(object):
self._y_test = None
self._w_train = None
self._w_test = None
self._scores_train = None
self._scores_test = None
self.s_eventlist_train = None
self.b_eventlist_train = None
......@@ -118,9 +123,6 @@ class KerasROOTClassification(object):
self._model = None
self._history = None
self._scores_train = None
self._scores_test = None
# track the number of epochs this model has been trained
self.total_epochs = 0
......@@ -132,7 +134,8 @@ class KerasROOTClassification(object):
try:
self._load_from_hdf5()
# if those don't exist, we need to load them from ROOT trees first
self._load_from_hdf5(*self.dataset_names_tree)
except KeyError:
......@@ -188,8 +191,7 @@ class KerasROOTClassification(object):
self.y_test[:len(self.s_test)] = 1
self.y_test[len(self.s_test):] = 0
logger.info("Writing to hdf5")
self._dump_to_hdf5()
self._dump_to_hdf5(*self.dataset_names_tree)
self.data_loaded = True
......@@ -202,8 +204,8 @@ class KerasROOTClassification(object):
s_eventlist.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))
def _dump_to_hdf5(self, dataset_names=None):
if dataset_names is None:
def _dump_to_hdf5(self, *dataset_names):
if len(dataset_names) < 1:
dataset_names = self.dataset_names
for dataset_name in dataset_names:
filename = os.path.join(self.project_dir, dataset_name+".h5")
......@@ -212,8 +214,8 @@ class KerasROOTClassification(object):
hf.create_dataset(dataset_name, data=getattr(self, dataset_name))
def _load_from_hdf5(self, dataset_names=None):
if dataset_names is None:
def _load_from_hdf5(self, *dataset_names):
if len(dataset_names) < 1:
dataset_names = self.dataset_names
for dataset_name in dataset_names:
filename = os.path.join(self.project_dir, dataset_name+".h5")
......@@ -223,33 +225,6 @@ class KerasROOTClassification(object):
logger.info("Data loaded")
@property
def scores_train(self):
if self._scores_train is None:
self._load_from_hdf5(["_scores_train"])
return self._scores_train
@scores_train.setter
def scores_train(self, value):
self._scores_train = value
self._dump_to_hdf5(["_scores_train"])
@property
def scores_test(self):
if self._scores_test is None:
self._load_from_hdf5(["_scores_test"])
return self._scores_test
@scores_test.setter
def scores_test(self, value):
self._scores_test = value
logger.info("dump")
self._dump_to_hdf5(["_scores_test"])
@property
def scaler(self):
# create the scaler (and fit to training data) if not existent
......@@ -313,6 +288,7 @@ class KerasROOTClassification(object):
logger.debug("training data after transformation: {}".format(self.x_train))
self.x_test = self.scaler.transform(self.x_test)
self.data_transformed = True
logger.info("Training and test data transformed")
def _read_info(self, key, default):
......@@ -389,6 +365,9 @@ class KerasROOTClassification(object):
np.random.shuffle(self.y_train)
np.random.set_state(rn_state)
np.random.shuffle(self.w_train)
if self._scores_test is not None:
np.random.set_state(rn_state)
np.random.shuffle(self._scores_test)
def train(self, epochs=10):
......@@ -424,8 +403,6 @@ class KerasROOTClassification(object):
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
print(self.history)
logger.info("Save history")
self._dump_history()
......@@ -439,6 +416,8 @@ class KerasROOTClassification(object):
self.scores_test = self.model.predict(self.x_test)
self.scores_train = self.model.predict(self.x_train)
self._dump_to_hdf5("scores_train", "scores_test")
def evaluate(self):
......@@ -566,13 +545,12 @@ class KerasROOTClassification(object):
def create_getter(dataset_name):
def getx(self):
if getattr(self, "_"+dataset_name) is None:
self._load_from_hdf5([dataset_name])
self._load_from_hdf5(dataset_name)
return getattr(self, "_"+dataset_name)
return getx
def create_setter(dataset_name):
def setx(self, value):
# maybe change this at some point to also dump into hdf
setattr(self, "_"+dataset_name, value)
return setx
......@@ -590,11 +568,14 @@ if __name__ == "__main__":
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
c = KerasROOTClassification("test3",
c = KerasROOTClassification("test4",
signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
bkg_trees = [(filename, "ttbar_NoSys"),
(filename, "wjets_Sherpa221_NoSys")
],
optimizer="SGD",
optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
# optimizer="Adam",
selection="lep1Pt<5000", # cut out a few very weird outliers
branches = ["met", "mt"],
weight_expr = "eventWeight*genWeight",
......@@ -602,7 +583,7 @@ if __name__ == "__main__":
step_bkg = 100)
#c.load()
c.train(epochs=10)
c.train(epochs=20)
c.plot_ROC()
c.plot_loss()
c.plot_accuracy()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment