Skip to content
Snippets Groups Projects
Commit 669306ab authored by Nikolai's avatar Nikolai
Browse files

read and save training list

parent 8be7b2f2
No related branches found
No related tags found
No related merge requests found
......@@ -112,8 +112,8 @@ class KerasROOTClassification(object):
self._scores_train = None
self._scores_test = None
self.s_eventlist_train = None
self.b_eventlist_train = None
self._s_eventlist_train = None
self._b_eventlist_train = None
self._scaler = None
self._class_weight = None
......@@ -164,9 +164,9 @@ class KerasROOTClassification(object):
selection=self.selection,
start=1, step=self.step_bkg)
self._dump_training_list()
self.s_eventlist_train = self.s_train[self.identifiers]
self.b_eventlist_train = self.b_train[self.identifiers]
self._dump_training_list()
# now we don't need the identifiers anymore
self.s_train = self.s_train[self.branches+[self.weight_expr]]
......@@ -196,11 +196,37 @@ class KerasROOTClassification(object):
def _dump_training_list(self):
s_eventlist = pd.DataFrame(self.s_train[self.identifiers])
b_eventlist = pd.DataFrame(self.b_train[self.identifiers])
s_eventlist_df = pd.DataFrame(self.s_eventlist_train)
b_eventlist_df = pd.DataFrame(self.b_eventlist_train)
s_eventlist_df.to_csv(os.path.join(self.project_dir, "s_eventlist_train.csv"))
b_eventlist_df.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))
@property
def s_eventlist_train(self):
if self._s_eventlist_train is None:
df = pd.read_csv(os.path.join(self.project_dir, "s_eventlist_train.csv"))
self._s_eventlist_train = df.to_records()[self.identifiers]
return self._s_eventlist_train
@s_eventlist_train.setter
def s_eventlist_train(self, value):
self._s_eventlist_train = value
@property
def b_eventlist_train(self):
if self._b_eventlist_train is None:
df = pd.read_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))
self._b_eventlist_train = df.to_records()[self.identifiers]
return self._b_eventlist_train
s_eventlist.to_csv(os.path.join(self.project_dir, "s_eventlist_train.csv"))
s_eventlist.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))
@b_eventlist_train.setter
def b_eventlist_train(self, value):
self._b_eventlist_train = value
def _dump_to_hdf5(self, *dataset_names):
......@@ -435,19 +461,20 @@ class KerasROOTClassification(object):
if os.path.exists(target_filename):
raise IOError("{} already exists, if you want to recreate it, delete it first".format(target_filename))
for start in range(0, entries, batch_size):
logger.info("Evaluating score for entry {}/{}".format(start, entries))
logger.debug("Loading next batch")
x_eval = rec2array(tree2array(tree,
branches=self.branches,
start=start, stop=start+batch_size))
logger.debug("Done")
scores = np.array(self.evaluate(x_eval), dtype=[(score_name, np.float64)])
print(len(scores))
print(scores)
if start == 0:
mode = "recreate"
else:
mode = "update"
logger.debug("Write to root file")
array2root(scores, target_filename, treename=target_treename, mode=mode)
logger.debug("Done")
def write_all_friend_trees(self):
pass
......@@ -589,7 +616,7 @@ if __name__ == "__main__":
logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
#logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment