From 4ed198f6228ffe9313bb6ef777f94ea8503f3c22 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <nikolai.hartmann@gmx.de> Date: Mon, 19 Nov 2018 09:46:53 +0100 Subject: [PATCH] fill regression targets optionally --- toolkit.py | 48 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/toolkit.py b/toolkit.py index 524a66f..dd7651a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -100,6 +100,8 @@ class ClassificationProject(object): :param branches: list of branch names or expressions to be used as input values for training + :param regression_branches: list of branch names to be used as regression targets + :param rename_branches: dictionary that maps branch expressions to names for better readability :param weight_expr: expression to weight the events in the loss function @@ -224,6 +226,7 @@ class ClassificationProject(object): def _init_from_args(self, name, signal_trees, bkg_trees, branches, weight_expr, + regression_branches=None, rename_branches=None, project_dir=None, data_dir=None, @@ -270,6 +273,9 @@ class ClassificationProject(object): if rename_branches is None: rename_branches = {} self.rename_branches = rename_branches + if regression_branches is None: + regression_branches = [] + self.regression_branches = regression_branches self.weight_expr = weight_expr self.selection = selection @@ -382,6 +388,7 @@ class ClassificationProject(object): self.is_training = False self._fields = None + self._target_fields = None @property @@ -394,6 +401,16 @@ class ClassificationProject(object): return self._fields + @property + def target_fields(self): + "Renamed branch expressions for regression targets" + if self._target_fields is None: + self._target_fields = [] + for branch_expr in self.regression_branches: + self._target_fields.append(self.rename_branches.get(branch_expr, branch_expr)) + return self._target_fields + + def rename_fields(self, ar): "Rename fields of structured array" fields = list(ar.dtype.names) @@ -424,19 +441,19 @@ class ClassificationProject(object): for filename, treename in self.bkg_trees: bkg_chain.AddFile(filename, -1, treename) self.s_train = tree2array(signal_chain, - branches=self.branches+[self.weight_expr]+self.identifiers, + branches=self.branches+self.regression_branches+[self.weight_expr]+self.identifiers, selection=self.selection, start=0, step=self.step_signal, stop=self.stop_train) self.b_train = tree2array(bkg_chain, - branches=self.branches+[self.weight_expr]+self.identifiers, + branches=self.branches+self.regression_branches+[self.weight_expr]+self.identifiers, selection=self.selection, start=0, step=self.step_bkg, stop=self.stop_train) self.s_test = tree2array(signal_chain, - branches=self.branches+[self.weight_expr], + branches=self.branches+self.regression_branches+[self.weight_expr], selection=self.selection, start=1, step=self.step_signal, stop=self.stop_test) self.b_test = tree2array(bkg_chain, - branches=self.branches+[self.weight_expr], + branches=self.branches+self.regression_branches+[self.weight_expr], selection=self.selection, start=1, step=self.step_bkg, stop=self.stop_test) @@ -459,9 +476,21 @@ class ClassificationProject(object): self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.fields]))) self.w_train = self.s_train[self.weight_expr] self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr])) - self.y_train = np.empty(len(self.x_train), dtype=np.bool) - self.y_train[:len(self.s_train)] = 1 - self.y_train[len(self.s_train):] = 0 + + def fill_target(x, s, b): + if not self.target_fields: + y = np.empty(len(x), dtype=np.bool) + y[:len(s)] = 1 + y[len(s):] = 0 + else: + y = np.empty((len(x), 1+len(self.target_fields)), dtype=np.float) + y[:len(s),0] = 1 + y[len(s):,0] = 0 + y[:len(s),1:] = s[self.target_fields] + y[len(s):,1:] = b[self.target_fields] + return y + + self.y_train = fill_target(self.x_train, self.s_train, self.b_train) self.b_train = None self.s_train = None @@ -469,9 +498,8 @@ class ClassificationProject(object): self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.fields]))) self.w_test = self.s_test[self.weight_expr] self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr])) - self.y_test = np.empty(len(self.x_test), dtype=np.bool) - self.y_test[:len(self.s_test)] = 1 - self.y_test[len(self.s_test):] = 0 + + self.y_test = fill_target(self.x_test, self.s_test, self.b_test) self.b_test = None self.s_test = None -- GitLab