From beea1cfe45d90cd1ffe294eec8086d92e9c29e8c Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Fri, 10 Aug 2018 15:10:26 +0200
Subject: [PATCH] Experimental support for initialising from pandas DataFrame
 (memory intense ...)

---
 toolkit.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 110 insertions(+), 6 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index a318c21..875cf44 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-__all__ = ["ClassificationProject"]
+__all__ = ["ClassificationProject", "ClassificationProjectDataFrame"]
 
 from sys import version_info
 
@@ -1171,9 +1171,12 @@ class ClassificationProject(object):
             categories=["background", "signal"]
         )
         for identifier in self.identifiers:
-            df[identifier] = np.concatenate([self.s_eventlist_train[identifier],
-                                             self.b_eventlist_train[identifier],
-                                             -1*np.ones(len(self.x_test), dtype="i8")])
+            try:
+                df[identifier] = np.concatenate([self.s_eventlist_train[identifier],
+                                                 self.b_eventlist_train[identifier],
+                                                 -1*np.ones(len(self.x_test), dtype="i8")])
+            except IOError:
+                logger.warning("Can't find eventlist - DataFrame won't contain identifiers")
         df["is_train"] = np.concatenate([np.ones(len(self.x_train), dtype=np.bool),
                                          np.zeros(len(self.x_test), dtype=np.bool)])
         return df
@@ -1204,15 +1207,116 @@ class ClassificationProjectDataFrame(ClassificationProject):
     """
 
     def __init__(self,
+                 name,
+                 df,
                  input_columns,
                  weight_column="weights",
                  label_column="labels",
                  signal_label="signal",
                  background_label="background",
                  split_mode="split_column",
-                 split_colurm="is_train",
+                 split_column="is_train",
                  **kwargs):
-        pass
+
+        self.df = df
+        self.input_columns = input_columns
+        self.weight_column = weight_column
+        self.label_column = label_column
+        self.signal_label = signal_label
+        self.background_label = background_label
+        if split_mode != "split_column":
+            raise NotImplementedError("'split_column' is the only currently supported split mode")
+        self.split_mode = split_mode
+        self.split_column = split_column
+        super(ClassificationProjectDataFrame, self).__init__(name,
+                                                             signal_trees=[], bkg_trees=[], branches=[], weight_expr="1",
+                                                             **kwargs)
+        self._x_train = None
+        self._x_test = None
+        self._y_train = None
+        self._y_test = None
+        self._w_train = None
+        self._w_test = None
+
+    @property
+    def x_train(self):
+        if self._x_train is None:
+            self._x_train = self.df[self.df[self.split_column]][self.input_columns].values
+        return self._x_train
+
+    @x_train.setter
+    def x_train(self, value):
+        self._x_train = value
+
+    @property
+    def x_test(self):
+        if self._x_test is None:
+            self._x_test = self.df[~self.df[self.split_column]][self.input_columns].values
+        return self._x_test
+
+    @x_test.setter
+    def x_test(self, value):
+        self._x_test = value
+
+    @property
+    def y_train(self):
+        if self._y_train is None:
+            self._y_train = (self.df[self.df[self.split_column]][self.label_column] == self.signal_label).values
+        return self._y_train
+
+    @y_train.setter
+    def y_train(self, value):
+        self._y_train = value
+
+    @property
+    def y_test(self):
+        if self._y_test is None:
+            self._y_test = (self.df[~self.df[self.split_column]][self.label_column] == self.signal_label).values
+        return self._y_test
+
+    @y_test.setter
+    def y_test(self, value):
+        self._y_test = value
+
+    @property
+    def w_train(self):
+        if self._w_train is None:
+            self._w_train = self.df[self.df[self.split_column]][self.weight_column].values
+        return self._w_train
+
+    @w_train.setter
+    def w_train(self, value):
+        self._w_train = value
+
+    @property
+    def w_test(self):
+        if self._w_test is None:
+            self._w_test = self.df[~self.df[self.split_column]][self.weight_column].values
+        return self._w_test
+
+    @w_test.setter
+    def w_test(self, value):
+        self._w_test = value
+
+    @property
+    def fields(self):
+        return self.input_columns
+
+
+    def load(self, reload=False):
+
+        if reload:
+            self.data_loaded = False
+            self.data_transformed = False
+            self._x_train = None
+            self._x_test = None
+            self._y_train = None
+            self._y_test = None
+            self._w_train = None
+            self._w_test = None
+
+        if not self.data_transformed:
+            self._transform_data()
 
 
 if __name__ == "__main__":
-- 
GitLab