Skip to content
Snippets Groups Projects
Commit 8be7b2f2 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

trying ...

parent 431c04dd
Branches dev-friend
No related tags found
No related merge requests found
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
logger = logging.getLogger("KerasROOTClassification") logger = logging.getLogger("KerasROOTClassification")
logger.addHandler(logging.NullHandler()) logger.addHandler(logging.NullHandler())
from root_numpy import tree2array, rec2array from root_numpy import tree2array, rec2array, array2root
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import h5py import h5py
...@@ -418,10 +418,38 @@ class KerasROOTClassification(object): ...@@ -418,10 +418,38 @@ class KerasROOTClassification(object):
def evaluate(self): def evaluate(self, x_eval):
pass logger.debug("Evaluate score for {}".format(x_eval))
x_eval = self.scaler.transform(x_eval)
def write_friend_tree(self): logger.debug("Evaluate for transformed array: {}".format(x_eval))
return self.model.predict(x_eval)
def write_friend_tree(self, score_name,
source_filename, source_treename,
target_filename, target_treename,
batch_size=100000):
f = ROOT.TFile.Open(source_filename)
tree = f.Get(source_treename)
entries = tree.GetEntries()
if os.path.exists(target_filename):
raise IOError("{} already exists, if you want to recreate it, delete it first".format(target_filename))
for start in range(0, entries, batch_size):
logger.debug("Loading next batch")
x_eval = rec2array(tree2array(tree,
branches=self.branches,
start=start, stop=start+batch_size))
scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
print(len(scores))
print(scores)
if start == 0:
mode = "recreate"
else:
mode = "update"
logger.debug("Write to root file")
array2root(scores, target_filename, treename=target_treename, mode=mode)
def write_all_friend_trees(self):
pass pass
...@@ -561,7 +589,7 @@ if __name__ == "__main__": ...@@ -561,7 +589,7 @@ if __name__ == "__main__":
logging.basicConfig() logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO) logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
#logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG) logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root" filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
...@@ -582,5 +610,15 @@ if __name__ == "__main__": ...@@ -582,5 +610,15 @@ if __name__ == "__main__":
c.load() c.load()
#c.train(epochs=20) #c.train(epochs=20)
c.plot_ROC() c.plot_ROC()
# c.plot_loss() c.plot_loss()
# c.plot_accuracy() c.plot_accuracy()
c.write_friend_tree("test4_score",
source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
target_filename="friend.root", target_treename="test4_score")
np.random.seed(1234)
c.write_friend_tree("test4_score",
source_filename=filename, source_treename="ttbar_NoSys",
target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment