diff --git a/toolkit.py b/toolkit.py index 3e75d2ccd65697af669af7016abf7905bfb769ac..67804358efbb463f23134d8bafa5f2205b01cb1a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1024,7 +1024,9 @@ class ClassificationProject(object): def write_friend_tree(self, score_name, source_filename, source_treename, target_filename, target_treename, - batch_size=100000, score_mode=None): + batch_size=100000, + score_mode=None, + fixed_params=None): f = ROOT.TFile.Open(source_filename) tree = f.Get(source_treename) entries = tree.GetEntries() @@ -1037,8 +1039,11 @@ class ClassificationProject(object): x_from_tree = tree2array(tree, branches=self.branches+self.identifiers, start=start, stop=start+batch_size) + # for parametrized classifiers + if fixed_params is not None: + for param_name, value in fixed_params.items(): + x_from_tree[param_name] = value x_eval = rec2array(x_from_tree[self.branches]) - if len(self.identifiers) > 0: # create list of booleans that indicate which events where used for training df_identifiers = pd.DataFrame(x_from_tree[self.identifiers])