diff --git a/toolkit.py b/toolkit.py index b6ec1f929a8eab5235c7893a174384ba2fff84ff..3e75d2ccd65697af669af7016abf7905bfb769ac 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1024,7 +1024,7 @@ class ClassificationProject(object): def write_friend_tree(self, score_name, source_filename, source_treename, target_filename, target_treename, - batch_size=100000): + batch_size=100000, score_mode=None): f = ROOT.TFile.Open(source_filename) tree = f.Get(source_treename) entries = tree.GetEntries() @@ -1050,7 +1050,7 @@ class ClassificationProject(object): is_train = np.zeros(len(x_eval)) # join scores and is_train array - scores = self.evaluate(x_eval).reshape(-1) + scores = self.evaluate(x_eval, mode=score_mode).reshape(-1) friend_df = pd.DataFrame(np.array(scores, dtype=[(score_name, np.float64)])) friend_df[score_name+"_is_train"] = is_train friend_tree = friend_df.to_records()[[score_name, score_name+"_is_train"]]