From 12b8426733f632a7c32e7dbe98516875308bb7a1 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Fri, 27 Apr 2018 16:54:07 +0200
Subject: [PATCH] scores managed as properties (saved and loaded from h5)

---
 toolkit.py | 90 ++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 64 insertions(+), 26 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 4fce688..3db07b0 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -36,7 +36,7 @@ K.set_session(session)
 
 import ROOT
 
-class KerasROOTClassification:
+class KerasROOTClassification(object):
 
 
     dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
@@ -51,7 +51,9 @@ class KerasROOTClassification:
                  validation_split=0.33,
                  activation_function='relu',
                  out_dir="./outputs",
-                 scaler_type="RobustScaler"):
+                 scaler_type="RobustScaler",
+                 step_signal=2,
+                 step_bkg=2):
         self.name = name
         self.signal_trees = signal_trees
         self.bkg_trees = bkg_trees
@@ -66,6 +68,8 @@ class KerasROOTClassification:
         self.activation_function = activation_function
         self.out_dir = out_dir
         self.scaler_type = scaler_type
+        self.step_signal = step_signal
+        self.step_bkg = step_bkg
 
         self.project_dir = os.path.join(self.out_dir, name)
 
@@ -94,8 +98,8 @@ class KerasROOTClassification:
         self._model = None
         self._history = None
 
-        self.score_train = None
-        self.score_test = None
+        self._scores_train = None
+        self._scores_test = None
 
         # track the number of epochs this model has been trained
         self.total_epochs = 0
@@ -124,19 +128,19 @@ class KerasROOTClassification:
             self.s_train = tree2array(signal_chain,
                                       branches=self.branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
-                                      start=0, step=2)
+                                      start=0, step=self.step_signal)
             self.b_train = tree2array(bkg_chain,
                                       branches=self.branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
-                                      start=0, step=200)
+                                      start=0, step=self.step_bkg)
             self.s_test = tree2array(signal_chain,
                                      branches=self.branches+[self.weight_expr],
                                      selection=self.selection,
-                                     start=1, step=2)
+                                     start=1, step=self.step_signal)
             self.b_test = tree2array(bkg_chain,
                                      branches=self.branches+[self.weight_expr],
                                      selection=self.selection,
-                                     start=1, step=200)
+                                     start=1, step=self.step_bkg)
 
             self._dump_training_list()
             self.s_eventlist_train = self.s_train[self.identifiers]
@@ -178,14 +182,20 @@ class KerasROOTClassification:
         s_eventlist.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))
 
 
-    def _dump_to_hdf5(self):
-        for dataset_name in self.dataset_names:
-            with h5py.File(os.path.join(self.project_dir, dataset_name+".h5"), "w") as hf:
+    def _dump_to_hdf5(self, dataset_names=None):
+        if dataset_names is None:
+            dataset_names = self.dataset_names
+        for dataset_name in dataset_names:
+            filename = os.path.join(self.project_dir, dataset_name+".h5")
+            logger.info("Writing {} to {}".format(dataset_name, filename))
+            with h5py.File(filename, "w") as hf:
                 hf.create_dataset(dataset_name, data=getattr(self, dataset_name))
 
 
-    def _load_from_hdf5(self):
-        for dataset_name in self.dataset_names:
+    def _load_from_hdf5(self, dataset_names=None):
+        if dataset_names is None:
+            dataset_names = self.dataset_names
+        for dataset_name in dataset_names:
             filename = os.path.join(self.project_dir, dataset_name+".h5")
             logger.info("Trying to load {} from {}".format(dataset_name, filename))
             with h5py.File(filename) as hf:
@@ -193,6 +203,45 @@ class KerasROOTClassification:
         logger.info("Data loaded")
 
 
+    @property
+    def scores_train(self):
+        """Return the model predictions for the training data.
+
+        If no scores are held in memory yet, ``_load_from_hdf5`` is asked
+        to read them from ``_scores_train.h5`` in the project directory.
+        """
+        if self._scores_train is None:
+            self._load_from_hdf5(["_scores_train"])
+        return self._scores_train
+
+
+    @scores_train.setter
+    def scores_train(self, value):
+        """Store the training scores and write them to ``_scores_train.h5``."""
+        self._scores_train = value
+        self._dump_to_hdf5(["_scores_train"])
+
+
+    @property
+    def scores_test(self):
+        """Return the model predictions for the test data.
+
+        If no scores are held in memory yet, ``_load_from_hdf5`` is asked
+        to read them from ``_scores_test.h5`` in the project directory.
+        """
+        if self._scores_test is None:
+            self._load_from_hdf5(["_scores_test"])
+        return self._scores_test
+
+
+    @scores_test.setter
+    def scores_test(self, value):
+        """Store the test scores and write them to ``_scores_test.h5``."""
+        # no extra log call here: _dump_to_hdf5 already logs the write
+        self._scores_test = value
+        self._dump_to_hdf5(["_scores_test"])
+
+
     @property
     def scaler(self):
         # create the scaler (and fit to training data) if not existent
@@ -309,13 +346,13 @@ class KerasROOTClassification:
 
         logger.info("Train model")
         self._history = self.model.fit(self.x_train,
-                       # the reshape might be unnescessary here
-                       self.y_train.reshape(-1, 1),
-                       epochs=epochs,
-                       validation_split = self.validation_split,
-                       class_weight=self.class_weight,
-                       shuffle=True,
-                       batch_size=self.batch_size)
+                                       # the reshape might be unnecessary here
+                                       self.y_train.reshape(-1, 1),
+                                       epochs=epochs,
+                                       validation_split = self.validation_split,
+                                       class_weight=self.class_weight,
+                                       shuffle=True,
+                                       batch_size=self.batch_size)
 
         logger.info("Save weights")
         self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
@@ -323,7 +360,7 @@ class KerasROOTClassification:
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
 
-        logger.info("Create scores for ROC curve")
+        logger.info("Create/Update scores for ROC curve")
         self.scores_test = self.model.predict(self.x_test)
         self.scores_train = self.model.predict(self.x_train)
 
@@ -454,12 +491,12 @@ class KerasROOTClassification:
 if __name__ == "__main__":
 
     logging.basicConfig()
-    #logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
-    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
+    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
+    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
 
     filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
 
-    c = KerasROOTClassification("test2",
+    c = KerasROOTClassification("test3",
                                 signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
                                 bkg_trees = [(filename, "ttbar_NoSys"),
                                              (filename, "wjets_Sherpa221_NoSys")
@@ -467,9 +504,10 @@ if __name__ == "__main__":
                                 selection="lep1Pt<5000", # cut out a few very weird outliers
                                 branches = ["met", "mt"],
                                 weight_expr = "eventWeight*genWeight",
-                                identifiers = ["DatasetNumber", "EventNumber"])
+                                identifiers = ["DatasetNumber", "EventNumber"],
+                                step_bkg = 100)
 
-    c.train(epochs=20)
+#    c.train(epochs=10)
     c.plot_ROC()
     c.plot_loss()
     c.plot_accuracy()
-- 
GitLab