From a592516c8c87324ff257cb4134d0b93b389c11d2 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Mon, 18 Jun 2018 13:22:24 +0200
Subject: [PATCH] try with more appending/replacing

---
 toolkit.py | 82 ++++++++++++++++++++++++++----------------------------
 1 file changed, 40 insertions(+), 42 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index d2120e0..5d8e74d 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -159,7 +159,7 @@ class ClassificationProject(object):
         else:
             # otherwise initialise new project
             self._init_from_args(name, *args, **kwargs)
-            with open(os.path.join(self.project_dir, "options.pickle"), "w") as of:
+            with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of:
                 pickle.dump(dict(args=args, kwargs=kwargs), of)
 
 
@@ -275,6 +275,38 @@ class ClassificationProject(object):
         self.is_training = False
 
 
+    def _append_from_tree(self, chain, step, is_train, is_signal):
+        """Read events from `chain` and append features, weights and labels to the train/test arrays."""
+        # test events use start=1 (train: start=0) so that for step > 1 both samples are disjoint
+        ar = tree2array(chain,
+                        branches=self.branches+[self.weight_expr]+self.identifiers,
+                        selection=self.selection,
+                        start=0 if is_train else 1, step=step)
+        if is_train:
+            eventlist = ar[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
+            if is_signal:
+                self.s_eventlist_train = eventlist
+            else:
+                self.b_eventlist_train = eventlist
+
+        # now we don't need the identifiers anymore
+        ar = ar[self.branches+[self.weight_expr]]
+
+        trainstring = "train" if is_train else "test"
+        signal_fun = np.ones if is_signal else np.zeros
+
+        setattr(self, "w_"+trainstring, np.concatenate([getattr(self, "w_"+trainstring), ar[self.weight_expr]]))
+        setattr(self, "x_"+trainstring, np.concatenate([getattr(self, "x_"+trainstring), rec2array(ar[self.branches])]))
+        setattr(self, "y_"+trainstring, np.concatenate([getattr(self, "y_"+trainstring), signal_fun(len(ar))]))
+
+
+    def _reset_data(self):  # start from empty arrays before appending tree data
+        for name in self.dataset_names:
+            setattr(self, name, np.array([]))
+        for name in ("x_train", "x_test"):  # 2D with zero rows so np.concatenate accepts them
+            setattr(self, name, getattr(self, name).reshape(0, len(self.branches)))
+
+
     def _load_data(self):
 
         try:
@@ -286,6 +318,8 @@ class ClassificationProject(object):
 
             logger.info("Couldn't load all datasets - reading from ROOT trees")
 
+            self._reset_data()
+
             # Read signal and background trees into structured numpy arrays
             signal_chain = ROOT.TChain()
             bkg_chain = ROOT.TChain()
@@ -293,49 +327,13 @@ class ClassificationProject(object):
                 signal_chain.AddFile(filename, -1, treename)
             for filename, treename in self.bkg_trees:
                 bkg_chain.AddFile(filename, -1, treename)
-            self.s_train = tree2array(signal_chain,
-                                      branches=self.branches+[self.weight_expr]+self.identifiers,
-                                      selection=self.selection,
-                                      start=0, step=self.step_signal)
-            self.b_train = tree2array(bkg_chain,
-                                      branches=self.branches+[self.weight_expr]+self.identifiers,
-                                      selection=self.selection,
-                                      start=0, step=self.step_bkg)
-            self.s_test = tree2array(signal_chain,
-                                     branches=self.branches+[self.weight_expr],
-                                     selection=self.selection,
-                                     start=1, step=self.step_signal)
-            self.b_test = tree2array(bkg_chain,
-                                     branches=self.branches+[self.weight_expr],
-                                     selection=self.selection,
-                                     start=1, step=self.step_bkg)
-
-            self.s_eventlist_train = self.s_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
-            self.b_eventlist_train = self.b_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
-            self._dump_training_list()
 
-            # now we don't need the identifiers anymore
-            self.s_train = self.s_train[self.branches+[self.weight_expr]]
-            self.b_train = self.b_train[self.branches+[self.weight_expr]]
-
-            # create x (input), y (target) and w (weights) arrays
-            # the first block will be signals, the second block backgrounds
-            self.x_train = rec2array(self.s_train[self.branches])
-            self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.branches])))
-            self.x_test = rec2array(self.s_test[self.branches])
-            self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.branches])))
-            self.w_train = self.s_train[self.weight_expr]
-            self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr]))
-            self.w_test = self.s_test[self.weight_expr]
-            self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr]))
-
-            self.y_train = np.empty(len(self.x_train))
-            self.y_train[:len(self.s_train)] = 1
-            self.y_train[len(self.s_train):] = 0
-            self.y_test = np.empty(len(self.x_test))
-            self.y_test[:len(self.s_test)] = 1
-            self.y_test[len(self.s_test):] = 0
+            self._append_from_tree(signal_chain, step=self.step_signal, is_train=True, is_signal=True)
+            self._append_from_tree(bkg_chain, step=self.step_bkg, is_train=True, is_signal=False)
+            self._append_from_tree(signal_chain, step=self.step_signal, is_train=False, is_signal=True)
+            self._append_from_tree(bkg_chain, step=self.step_bkg, is_train=False, is_signal=False)
 
+            self._dump_training_list()
             self._dump_to_hdf5(*self.dataset_names_tree)
 
         self.data_loaded = True
-- 
GitLab