From 8be7b2f2a09afd142076fc36c79bef58813b3a70 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 30 Apr 2018 15:11:01 +0200
Subject: [PATCH] trying ...

---
 toolkit.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 46 insertions(+), 8 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 9904ec7..822d71c 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -9,7 +9,7 @@ import logging
 logger = logging.getLogger("KerasROOTClassification")
 logger.addHandler(logging.NullHandler())
 
-from root_numpy import tree2array, rec2array
+from root_numpy import tree2array, rec2array, array2root
 import numpy as np
 import pandas as pd
 import h5py
@@ -418,10 +418,38 @@ class KerasROOTClassification(object):
 
 
 
-    def evaluate(self):
-        pass
-
-    def write_friend_tree(self):
+    def evaluate(self, x_eval):
+        logger.debug("Evaluate score for {}".format(x_eval))
+        x_eval = self.scaler.transform(x_eval)
+        logger.debug("Evaluate for transformed array: {}".format(x_eval))
+        return self.model.predict(x_eval)
+
+
+    def write_friend_tree(self, score_name,
+                          source_filename, source_treename,
+                          target_filename, target_treename,
+                          batch_size=100000):
+        f = ROOT.TFile.Open(source_filename)
+        tree = f.Get(source_treename)
+        entries = tree.GetEntries()
+        if os.path.exists(target_filename):
+            raise IOError("{} already exists, if you want to recreate it, delete it first".format(target_filename))
+        for start in range(0, entries, batch_size):
+            logger.debug("Loading next batch")
+            x_eval = rec2array(tree2array(tree,
+                                          branches=self.branches,
+                                          start=start, stop=start+batch_size))
+            scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
+            print(len(scores))
+            print(scores)
+            if start == 0:
+                mode = "recreate"
+            else:
+                mode = "update"
+            logger.debug("Write to root file")
+            array2root(scores, target_filename, treename=target_treename, mode=mode)
+
+    def write_all_friend_trees(self):
         pass
 
 
@@ -561,7 +589,7 @@ if __name__ == "__main__":
 
     logging.basicConfig()
     logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
-    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
+    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
 
     filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
 
@@ -582,5 +610,15 @@ if __name__ == "__main__":
     c.load()
     #c.train(epochs=20)
     c.plot_ROC()
-    # c.plot_loss()
-    # c.plot_accuracy()
+    c.plot_loss()
+    c.plot_accuracy()
+
+    c.write_friend_tree("test4_score",
+                        source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
+                        target_filename="friend.root", target_treename="test4_score")
+
+    np.random.seed(1234)
+
+    c.write_friend_tree("test4_score",
+                        source_filename=filename, source_treename="ttbar_NoSys",
+                        target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")
-- 
GitLab