Skip to content
Snippets Groups Projects
Commit a592516c authored by Nikolai's avatar Nikolai
Browse files

try with more appending/replacing

parent 83c2f0d9
Branches dev-memory
No related tags found
No related merge requests found
......@@ -159,7 +159,7 @@ class ClassificationProject(object):
else:
# otherwise initialise new project
self._init_from_args(name, *args, **kwargs)
with open(os.path.join(self.project_dir, "options.pickle"), "w") as of:
with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of:
pickle.dump(dict(args=args, kwargs=kwargs), of)
......@@ -275,6 +275,38 @@ class ClassificationProject(object):
self.is_training = False
def _append_from_tree(self, chain, step, is_train, is_signal):
    """Read one sample from ``chain`` and append it to the x/y/w arrays.

    The events are read with ``tree2array`` using ``self.selection`` and
    appended to ``self.x_<set>``, ``self.y_<set>`` and ``self.w_<set>``,
    where ``<set>`` is ``"train"`` or ``"test"``. Those attributes must
    already exist as arrays, since the new rows are concatenated onto
    them.

    :param chain: tree/chain object handed to ``tree2array``
    :param step: stride passed to ``tree2array`` as ``step``
    :param is_train: if true fill the training arrays, else the test arrays
    :param is_signal: if true label the rows with 1, else with 0
    """
    # Training rows are read starting at entry 0 and test rows starting at
    # entry 1. Reading both with start=0 (same chain, same step) would make
    # the test sample an exact copy of the training sample.
    # NOTE(review): the two samples are only disjoint for step > 1 - confirm
    # that callers never pass step == 1.
    start = 0 if is_train else 1

    columns = self.branches + [self.weight_expr]
    # The identifier branches are only used for the training event list, so
    # they are not read at all for the test sample.
    read_branches = columns + self.identifiers if is_train else columns

    ar = tree2array(chain,
                    branches=read_branches,
                    selection=self.selection,
                    start=start, step=step)

    if is_train:
        eventlist = ar[self.identifiers].astype(
            dtype=[(branch_name, "u8") for branch_name in self.identifiers]
        )
        if is_signal:
            self.s_eventlist_train = eventlist
        else:
            self.b_eventlist_train = eventlist
        # now we don't need the identifiers anymore
        ar = ar[columns]

    suffix = "train" if is_train else "test"
    make_labels = np.ones if is_signal else np.zeros
    new_rows = (
        ("w_", ar[self.weight_expr]),
        ("x_", rec2array(ar[self.branches])),
        ("y_", make_labels(len(ar))),
    )
    for prefix, rows in new_rows:
        attr = prefix + suffix
        setattr(self, attr, np.concatenate([getattr(self, attr), rows]))
def _reset_data(self):
    """Replace every dataset attribute with an empty array.

    Each name in ``self.dataset_names`` is set to an empty 1-d numpy array.
    ``x_train`` and ``x_test`` are then reshaped to ``(0, len(self.branches))``
    so that rows with one column per branch can be concatenated onto them.
    """
    for dataset_name in self.dataset_names:
        setattr(self, dataset_name, np.array([]))

    n_columns = len(self.branches)
    for attr in ("x_train", "x_test"):
        setattr(self, attr, getattr(self, attr).reshape(0, n_columns))
def _load_data(self):
try:
......@@ -286,6 +318,8 @@ class ClassificationProject(object):
logger.info("Couldn't load all datasets - reading from ROOT trees")
self._reset_data()
# Read signal and background trees into structured numpy arrays
signal_chain = ROOT.TChain()
bkg_chain = ROOT.TChain()
......@@ -293,49 +327,13 @@ class ClassificationProject(object):
signal_chain.AddFile(filename, -1, treename)
for filename, treename in self.bkg_trees:
bkg_chain.AddFile(filename, -1, treename)
self.s_train = tree2array(signal_chain,
branches=self.branches+[self.weight_expr]+self.identifiers,
selection=self.selection,
start=0, step=self.step_signal)
self.b_train = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr]+self.identifiers,
selection=self.selection,
start=0, step=self.step_bkg)
self.s_test = tree2array(signal_chain,
branches=self.branches+[self.weight_expr],
selection=self.selection,
start=1, step=self.step_signal)
self.b_test = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr],
selection=self.selection,
start=1, step=self.step_bkg)
self.s_eventlist_train = self.s_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
self.b_eventlist_train = self.b_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
self._dump_training_list()
# now we don't need the identifiers anymore
self.s_train = self.s_train[self.branches+[self.weight_expr]]
self.b_train = self.b_train[self.branches+[self.weight_expr]]
# create x (input), y (target) and w (weights) arrays
# the first block will be signals, the second block backgrounds
self.x_train = rec2array(self.s_train[self.branches])
self.x_train = np.concatenate((self.x_train, rec2array(self.b_train[self.branches])))
self.x_test = rec2array(self.s_test[self.branches])
self.x_test = np.concatenate((self.x_test, rec2array(self.b_test[self.branches])))
self.w_train = self.s_train[self.weight_expr]
self.w_train = np.concatenate((self.w_train, self.b_train[self.weight_expr]))
self.w_test = self.s_test[self.weight_expr]
self.w_test = np.concatenate((self.w_test, self.b_test[self.weight_expr]))
self.y_train = np.empty(len(self.x_train))
self.y_train[:len(self.s_train)] = 1
self.y_train[len(self.s_train):] = 0
self.y_test = np.empty(len(self.x_test))
self.y_test[:len(self.s_test)] = 1
self.y_test[len(self.s_test):] = 0
self._append_from_tree(signal_chain, step=self.step_signal, is_train=True, is_signal=True)
self._append_from_tree(bkg_chain, step=self.step_bkg, is_train=True, is_signal=False)
self._append_from_tree(signal_chain, step=self.step_signal, is_train=False, is_signal=True)
self._append_from_tree(bkg_chain, step=self.step_bkg, is_train=False, is_signal=False)
self._dump_training_list()
self._dump_to_hdf5(*self.dataset_names_tree)
self.data_loaded = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment