From 0b4a434411ee6c854928626c24601116c3563ca5 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 1 May 2018 15:14:31 +0200
Subject: [PATCH] Write information which events were used for training into
 friend tree

---
 toolkit.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index e5fccaf..15ef8b6 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -463,17 +463,29 @@ class KerasROOTClassification(object):
         for start in range(0, entries, batch_size):
             logger.info("Evaluating score for entry {}/{}".format(start, entries))
             logger.debug("Loading next batch")
-            x_eval = rec2array(tree2array(tree,
-                                          branches=self.branches,
-                                          start=start, stop=start+batch_size))
-            logger.debug("Done")
-            scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
+            x_from_tree = tree2array(tree,
+                                     branches=self.branches+self.identifiers,
+                                     start=start, stop=start+batch_size)
+            x_eval = rec2array(x_from_tree[self.branches])
+
+            # create list of booleans that indicate which events were used for training
+            df_identifiers = pd.DataFrame(x_from_tree[self.identifiers])
+            total_train_list = self.s_eventlist_train
+            total_train_list = np.concatenate((total_train_list, self.b_eventlist_train))
+            merged = df_identifiers.merge(pd.DataFrame(total_train_list), on=tuple(self.identifiers), indicator=True, how="left")
+            is_train = np.array(merged["_merge"] == "both")
+
+            # join scores and is_train array
+            scores = self.evaluate(x_eval).reshape(-1)
+            friend_df = pd.DataFrame(np.array(scores, dtype=[(score_name, np.float64)]))
+            friend_df["is_train"] = is_train
+            friend_tree = friend_df.to_records()[[score_name, "is_train"]]
             if start == 0:
                 mode = "recreate"
             else:
                 mode = "update"
             logger.debug("Write to root file")
-            array2root(scores, target_filename, treename=target_treename, mode=mode)
+            array2root(friend_tree, target_filename, treename=target_treename, mode=mode)
             logger.debug("Done")
 
     def write_all_friend_trees(self):
-- 
GitLab