From 2d10bf67fe3e54dca0d302b6934dc3a6cfda943b Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 26 Apr 2018 13:41:07 +0200
Subject: [PATCH] load data from hdf5 if it already exists

---
 toolkit.py | 103 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 64 insertions(+), 39 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index d169f4c..0240e19 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -2,6 +2,10 @@
 
 import os
 
+import logging
+logger = logging.getLogger("KerasROOTClassification")
+logger.addHandler(logging.NullHandler())
+
 from root_numpy import tree2array, rec2array
 import numpy as np
 import pandas as pd
@@ -11,6 +15,8 @@ import ROOT
 
 class KerasROOTClassification:
 
+    dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
+
     def __init__(self, name,
                  signal_trees, bkg_trees, branches, weight_expr, identifiers,
                  layers=3, nodes=64, out_dir="./outputs"):
@@ -46,38 +52,54 @@ class KerasROOTClassification:
 
     def load_data(self):
 
-        # Read signal and background trees into structured numpy arrays
-        signal_chain = ROOT.TChain()
-        bkg_chain = ROOT.TChain()
-        for filename, treename in self.signal_trees:
-            signal_chain.AddFile(filename, -1, treename)
-        for filename, treename in self.bkg_trees:
-            bkg_chain.AddFile(filename, -1, treename)
-        self.s_train = tree2array(signal_chain, branches=self.branches+[self.weight_expr]+self.identifiers, start=0, step=2)
-        self.b_train = tree2array(bkg_chain, branches=self.branches+[self.weight_expr]+self.identifiers, start=0, step=2)
-        self.s_test = tree2array(signal_chain, branches=self.branches+[self.weight_expr], start=1, step=2)
-        self.b_test = tree2array(bkg_chain, branches=self.branches+[self.weight_expr], start=1, step=2)
-
-        self._dump_training_list()
-        self.s_eventlist_train = self.s_train[self.identifiers]
-        self.b_eventlist_train = self.b_train[self.identifiers]
-
-        # now we don't need the identifiers anymore
-        self.s_train = self.s_train[self.branches+[self.weight_expr]]
-        self.b_train = self.b_train[self.branches+[self.weight_expr]]
-
-        # create x (input), y (target) and w (weights) arrays
-        # the first block will be signals, the second block backgrounds
-        self.x_train = rec2array(self.s_train[self.branches])
-        self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.branches])))
-        self.x_test = rec2array(self.s_test[self.branches])
-        self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.branches])))
-        self.w_train = self.s_train[self.weight_expr]
-        self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr]))
-        self.w_test = self.s_test[self.weight_expr]
-        self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr]))
-
-        self._dump_to_hdf5()
+        try:
+
+            self._load_from_hdf5()
+
+        except KeyError:
+
+            logger.info("Couldn't load all datasets - reading from ROOT trees")
+
+            # Read signal and background trees into structured numpy arrays
+            signal_chain = ROOT.TChain()
+            bkg_chain = ROOT.TChain()
+            for filename, treename in self.signal_trees:
+                signal_chain.AddFile(filename, -1, treename)
+            for filename, treename in self.bkg_trees:
+                bkg_chain.AddFile(filename, -1, treename)
+            self.s_train = tree2array(signal_chain, branches=self.branches+[self.weight_expr]+self.identifiers, start=0, step=2)
+            self.b_train = tree2array(bkg_chain, branches=self.branches+[self.weight_expr]+self.identifiers, start=0, step=2)
+            self.s_test = tree2array(signal_chain, branches=self.branches+[self.weight_expr], start=1, step=2)
+            self.b_test = tree2array(bkg_chain, branches=self.branches+[self.weight_expr], start=1, step=2)
+
+            self._dump_training_list()
+            self.s_eventlist_train = self.s_train[self.identifiers]
+            self.b_eventlist_train = self.b_train[self.identifiers]
+
+            # now we don't need the identifiers anymore
+            self.s_train = self.s_train[self.branches+[self.weight_expr]]
+            self.b_train = self.b_train[self.branches+[self.weight_expr]]
+
+            # create x (input), y (target) and w (weights) arrays
+            # the first block will be signals, the second block backgrounds
+            self.x_train = rec2array(self.s_train[self.branches])
+            self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.branches])))
+            self.x_test = rec2array(self.s_test[self.branches])
+            self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.branches])))
+            self.w_train = self.s_train[self.weight_expr]
+            self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr]))
+            self.w_test = self.s_test[self.weight_expr]
+            self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr]))
+
+            self.y_train = np.empty(len(self.x_train))
+            self.y_train[:len(self.s_train)] = 1
+            self.y_train[len(self.s_train):] = 0
+            self.y_test = np.empty(len(self.x_test))
+            self.y_test[:len(self.s_test)] = 1
+            self.y_test[len(self.s_test):] = 0
+
+            logger.info("Writing to hdf5")
+            self._dump_to_hdf5()
 
 
     def _dump_training_list(self):
@@ -89,18 +111,18 @@ class KerasROOTClassification:
 
 
     def _dump_to_hdf5(self):
-        for dataset_name in ["x_train", "x_test"]:
+        for dataset_name in self.dataset_names:
             with h5py.File(os.path.join(self.project_dir, dataset_name+".h5"), "w") as hf:
                 hf.create_dataset(dataset_name, data=getattr(self, dataset_name))
 
 
     def _load_from_hdf5(self):
-        dataset_names = ["x_train", "x_test"]
-
-
-        # example:
-        with h5py.File("x_test.h5") as hf:
-            self.x_test = hf["x_test"][:]
+        for dataset_name in self.dataset_names:
+            filename = os.path.join(self.project_dir, dataset_name+".h5")
+            logger.info("Trying to load {} from {}".format(dataset_name, filename))
+            with h5py.File(filename) as hf:
+                setattr(self, dataset_name,  hf[dataset_name][:])
+        logger.info("Data loaded")
 
 
     def train(self):
@@ -122,6 +144,9 @@ class KerasROOTClassification:
 
 if __name__ == "__main__":
 
+    logging.basicConfig()
+    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
+
     filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
 
     c = KerasROOTClassification("test",
-- 
GitLab