From 4ed198f6228ffe9313bb6ef777f94ea8503f3c22 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Mon, 19 Nov 2018 09:46:53 +0100
Subject: [PATCH] fill regression targets optionally

---
 toolkit.py | 48 ++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 524a66f..dd7651a 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -100,6 +100,8 @@ class ClassificationProject(object):
 
     :param branches: list of branch names or expressions to be used as input values for training
 
+    :param regression_branches: optional list of branch names to be used as regression targets (default: none)
+
     :param rename_branches: dictionary that maps branch expressions to names for better readability
 
     :param weight_expr: expression to weight the events in the loss function
@@ -224,6 +226,7 @@ class ClassificationProject(object):
 
     def _init_from_args(self, name,
                         signal_trees, bkg_trees, branches, weight_expr,
+                        regression_branches=None,
                         rename_branches=None,
                         project_dir=None,
                         data_dir=None,
@@ -270,6 +273,9 @@ class ClassificationProject(object):
         if rename_branches is None:
             rename_branches = {}
         self.rename_branches = rename_branches
+        if regression_branches is None:
+            regression_branches = []
+        self.regression_branches = regression_branches
         self.weight_expr = weight_expr
         self.selection = selection
 
@@ -382,6 +388,7 @@ class ClassificationProject(object):
         self.is_training = False
 
         self._fields = None
+        self._target_fields = None
 
 
     @property
@@ -394,6 +401,16 @@ class ClassificationProject(object):
         return self._fields
 
 
+    @property
+    def target_fields(self):
+        "Field names of the regression targets, after applying rename_branches"
+        if self._target_fields is None:
+            # computed once and cached; branches without a rename entry keep their expression
+            self._target_fields = [self.rename_branches.get(expr, expr)
+                                   for expr in self.regression_branches]
+        return self._target_fields
+
+
     def rename_fields(self, ar):
         "Rename fields of structured array"
         fields = list(ar.dtype.names)
@@ -424,19 +441,19 @@ class ClassificationProject(object):
             for filename, treename in self.bkg_trees:
                 bkg_chain.AddFile(filename, -1, treename)
             self.s_train = tree2array(signal_chain,
-                                      branches=self.branches+[self.weight_expr]+self.identifiers,
+                                      branches=self.branches+self.regression_branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
                                       start=0, step=self.step_signal, stop=self.stop_train)
             self.b_train = tree2array(bkg_chain,
-                                      branches=self.branches+[self.weight_expr]+self.identifiers,
+                                      branches=self.branches+self.regression_branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
                                       start=0, step=self.step_bkg, stop=self.stop_train)
             self.s_test = tree2array(signal_chain,
-                                     branches=self.branches+[self.weight_expr],
+                                     branches=self.branches+self.regression_branches+[self.weight_expr],
                                      selection=self.selection,
                                      start=1, step=self.step_signal, stop=self.stop_test)
             self.b_test = tree2array(bkg_chain,
-                                     branches=self.branches+[self.weight_expr],
+                                     branches=self.branches+self.regression_branches+[self.weight_expr],
                                      selection=self.selection,
                                      start=1, step=self.step_bkg, stop=self.stop_test)
 
@@ -459,9 +476,21 @@ class ClassificationProject(object):
             self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.fields])))
             self.w_train = self.s_train[self.weight_expr]
             self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr]))
-            self.y_train = np.empty(len(self.x_train), dtype=np.bool)
-            self.y_train[:len(self.s_train)] = 1
-            self.y_train[len(self.s_train):] = 0
+
+            def fill_target(x, s, b):
+                "Targets for x: signal flag (rows of s come first), plus one column per regression target"
+                is_signal = np.arange(len(x)) < len(s)
+                if not self.target_fields:
+                    return is_signal
+                y = np.empty((len(x), 1+len(self.target_fields)), dtype=float)
+                y[:, 0] = is_signal
+                # copy field by field: a structured array cannot be assigned to a 2D float slice directly
+                for column, field in enumerate(self.target_fields, start=1):
+                    y[:len(s), column] = s[field]
+                    y[len(s):, column] = b[field]
+                return y
+
+            self.y_train = fill_target(self.x_train, self.s_train, self.b_train)
             self.b_train = None
             self.s_train = None
 
@@ -469,9 +498,8 @@ class ClassificationProject(object):
             self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.fields])))
             self.w_test = self.s_test[self.weight_expr]
             self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr]))
-            self.y_test = np.empty(len(self.x_test), dtype=np.bool)
-            self.y_test[:len(self.s_test)] = 1
-            self.y_test[len(self.s_test):] = 0
+
+            self.y_test = fill_target(self.x_test, self.s_test, self.b_test)
             self.b_test = None
             self.s_test = None
 
-- 
GitLab