diff --git a/toolkit.py b/toolkit.py
index 9904ec7be590293503af6a515d64c599df07c115..822d71c52e4fd794fe0f2739ffb28822f32b5d43 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -9,7 +9,7 @@ import logging
 logger = logging.getLogger("KerasROOTClassification")
 logger.addHandler(logging.NullHandler())
 
-from root_numpy import tree2array, rec2array
+from root_numpy import tree2array, rec2array, array2root
 import numpy as np
 import pandas as pd
 import h5py
@@ -418,10 +418,60 @@ class KerasROOTClassification(object):
 
 
 
-    def evaluate(self):
-        pass
-
-    def write_friend_tree(self):
+    def evaluate(self, x_eval):
+        """Return the model output for the scaled version of ``x_eval``.
+
+        ``x_eval`` is passed through ``self.scaler.transform`` and the
+        result is handed to ``self.model.predict``.
+        """
+        # Lazy %-style arguments: the array is only formatted when DEBUG
+        # logging is actually enabled.
+        logger.debug("Evaluate score for %s", x_eval)
+        x_eval = self.scaler.transform(x_eval)
+        logger.debug("Evaluate for transformed array: %s", x_eval)
+        return self.model.predict(x_eval)
+
+
+    def write_friend_tree(self, score_name,
+                          source_filename, source_treename,
+                          target_filename, target_treename,
+                          batch_size=100000):
+        """Write the scores for one source tree into a new ROOT file.
+
+        The branches ``self.branches`` of ``source_treename`` in
+        ``source_filename`` are read in chunks of ``batch_size`` entries.
+        Each chunk is passed to ``self.evaluate`` and the result is written
+        under the field name ``score_name`` to ``target_treename`` in
+        ``target_filename``.
+
+        Raises IOError if ``target_filename`` already exists or if
+        ``source_filename`` cannot be opened.
+        """
+        # Refuse to overwrite an existing target before opening the source.
+        if os.path.exists(target_filename):
+            raise IOError("{} already exists, if you want to recreate it, delete it first".format(target_filename))
+        f = ROOT.TFile.Open(source_filename)
+        if not f or f.IsZombie():
+            raise IOError("Could not open {}".format(source_filename))
+        try:
+            tree = f.Get(source_treename)
+            entries = tree.GetEntries()
+            for start in range(0, entries, batch_size):
+                logger.debug("Loading next batch")
+                x_eval = rec2array(tree2array(tree,
+                                              branches=self.branches,
+                                              start=start, stop=start+batch_size))
+                scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
+                logger.debug("Evaluated %d scores: %s", len(scores), scores)
+                # The first batch creates the target file, later ones append.
+                mode = "recreate" if start == 0 else "update"
+                logger.debug("Write to root file")
+                array2root(scores, target_filename, treename=target_treename, mode=mode)
+        finally:
+            # Always release the source file, also when a batch fails.
+            f.Close()
+
+
     def write_all_friend_trees(self):
         pass
 
@@ -561,7 +611,7 @@ if __name__ == "__main__":
 
     logging.basicConfig()
     logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
-    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
+    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
 
     filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
 
@@ -582,5 +632,15 @@ if __name__ == "__main__":
     c.load()
     #c.train(epochs=20)
     c.plot_ROC()
-    # c.plot_loss()
-    # c.plot_accuracy()
+    c.plot_loss()
+    c.plot_accuracy()
+
+    c.write_friend_tree("test4_score",
+                        source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
+                        target_filename="friend.root", target_treename="test4_score")
+
+    np.random.seed(1234)
+
+    c.write_friend_tree("test4_score",
+                        source_filename=filename, source_treename="ttbar_NoSys",
+                        target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")