Skip to content
Snippets Groups Projects
Commit 8be7b2f2 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

trying ...

parent 431c04dd
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,7 @@ import logging
logger = logging.getLogger("KerasROOTClassification")
logger.addHandler(logging.NullHandler())
from root_numpy import tree2array, rec2array
from root_numpy import tree2array, rec2array, array2root
import numpy as np
import pandas as pd
import h5py
......@@ -418,10 +418,38 @@ class KerasROOTClassification(object):
def evaluate(self):
pass
def write_friend_tree(self):
def evaluate(self, x_eval):
logger.debug("Evaluate score for {}".format(x_eval))
x_eval = self.scaler.transform(x_eval)
logger.debug("Evaluate for transformed array: {}".format(x_eval))
return self.model.predict(x_eval)
def write_friend_tree(self, score_name,
source_filename, source_treename,
target_filename, target_treename,
batch_size=100000):
f = ROOT.TFile.Open(source_filename)
tree = f.Get(source_treename)
entries = tree.GetEntries()
if os.path.exists(target_filename):
raise IOError("{} already exists, if you want to recreate it, delete it first".format(target_filename))
for start in range(0, entries, batch_size):
logger.debug("Loading next batch")
x_eval = rec2array(tree2array(tree,
branches=self.branches,
start=start, stop=start+batch_size))
scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
print(len(scores))
print(scores)
if start == 0:
mode = "recreate"
else:
mode = "update"
logger.debug("Write to root file")
array2root(scores, target_filename, treename=target_treename, mode=mode)
def write_all_friend_trees(self):
pass
......@@ -561,7 +589,7 @@ if __name__ == "__main__":
logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
#logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
......@@ -582,5 +610,15 @@ if __name__ == "__main__":
c.load()
#c.train(epochs=20)
c.plot_ROC()
# c.plot_loss()
# c.plot_accuracy()
c.plot_loss()
c.plot_accuracy()
c.write_friend_tree("test4_score",
source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
target_filename="friend.root", target_treename="test4_score")
np.random.seed(1234)
c.write_friend_tree("test4_score",
source_filename=filename, source_treename="ttbar_NoSys",
target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment