Skip to content
Snippets Groups Projects
Commit 0b4a4344 authored by Nikolai's avatar Nikolai
Browse files

Write information which events where used for training into friend tree

parent 669306ab
No related branches found
No related tags found
No related merge requests found
...@@ -463,17 +463,29 @@ class KerasROOTClassification(object): ...@@ -463,17 +463,29 @@ class KerasROOTClassification(object):
for start in range(0, entries, batch_size): for start in range(0, entries, batch_size):
logger.info("Evaluating score for entry {}/{}".format(start, entries)) logger.info("Evaluating score for entry {}/{}".format(start, entries))
logger.debug("Loading next batch") logger.debug("Loading next batch")
x_eval = rec2array(tree2array(tree, x_from_tree = tree2array(tree,
branches=self.branches, branches=self.branches+self.identifiers,
start=start, stop=start+batch_size)) start=start, stop=start+batch_size)
logger.debug("Done") x_eval = rec2array(x_from_tree[self.branches])
scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
# create list of booleans that indicate which events where used for training
df_identifiers = pd.DataFrame(x_from_tree[self.identifiers])
total_train_list = self.s_eventlist_train
total_train_list = np.concatenate((total_train_list, self.b_eventlist_train))
merged = df_identifiers.merge(pd.DataFrame(total_train_list), on=tuple(self.identifiers), indicator=True, how="left")
is_train = np.array(merged["_merge"] == "both")
# join scores and is_train array
scores = self.evaluate(x_eval).reshape(-1)
friend_df = pd.DataFrame(np.array(scores, dtype=[(score_name, np.float64)]))
friend_df["is_train"] = is_train
friend_tree = friend_df.to_records()[[score_name, "is_train"]]
if start == 0: if start == 0:
mode = "recreate" mode = "recreate"
else: else:
mode = "update" mode = "update"
logger.debug("Write to root file") logger.debug("Write to root file")
array2root(scores, target_filename, treename=target_treename, mode=mode) array2root(friend_tree, target_filename, treename=target_treename, mode=mode)
logger.debug("Done") logger.debug("Done")
def write_all_friend_trees(self): def write_all_friend_trees(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment