From c5c9734a852d1ff6868ec5291d2eeddafa1e5f0c Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 5 Jun 2018 14:26:36 +0200 Subject: [PATCH] Identifiers option now optional --- toolkit.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/toolkit.py b/toolkit.py index 96fb4c9..fa96e41 100755 --- a/toolkit.py +++ b/toolkit.py @@ -151,7 +151,8 @@ class ClassificationProject(object): def _init_from_args(self, name, - signal_trees, bkg_trees, branches, weight_expr, identifiers, + signal_trees, bkg_trees, branches, weight_expr, + identifiers=None, selection=None, layers=3, nodes=64, @@ -178,6 +179,8 @@ class ClassificationProject(object): self.branches = branches self.weight_expr = weight_expr self.selection = selection + if identifiers is None: + identifiers = [] self.identifiers = identifiers self.layers = layers self.nodes = nodes @@ -730,12 +733,15 @@ class ClassificationProject(object): start=start, stop=start+batch_size) x_eval = rec2array(x_from_tree[self.branches]) - # create list of booleans that indicate which events where used for training - df_identifiers = pd.DataFrame(x_from_tree[self.identifiers]) - total_train_list = self.s_eventlist_train - total_train_list = np.concatenate((total_train_list, self.b_eventlist_train)) - merged = df_identifiers.merge(pd.DataFrame(total_train_list), on=tuple(self.identifiers), indicator=True, how="left") - is_train = np.array(merged["_merge"] == "both") + if len(self.identifiers) > 0: + # create list of booleans that indicate which events where used for training + df_identifiers = pd.DataFrame(x_from_tree[self.identifiers]) + total_train_list = self.s_eventlist_train + total_train_list = np.concatenate((total_train_list, self.b_eventlist_train)) + merged = df_identifiers.merge(pd.DataFrame(total_train_list), on=tuple(self.identifiers), indicator=True, how="left") + is_train = np.array(merged["_merge"] == "both") + else: + is_train = np.zeros(len(x_eval)) # join scores and is_train array scores = self.evaluate(x_eval).reshape(-1) -- GitLab