Skip to content
Snippets Groups Projects
Commit 013cb179 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

planing in 1D implemented

parent 20f895ff
No related branches found
No related tags found
No related merge requests found
...@@ -187,10 +187,10 @@ class ClassificationProject(object): ...@@ -187,10 +187,10 @@ class ClassificationProject(object):
# Datasets that are stored to (and dynamically loaded from) hdf5 # Datasets that are stored to (and dynamically loaded from) hdf5
dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test", "planing_array"] dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test"]
# Datasets that are retrieved from ROOT trees the first time # Datasets that are retrieved from ROOT trees the first time
dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "planing_array"] dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
def __init__(self, name, *args, **kwargs): def __init__(self, name, *args, **kwargs):
if len(args) < 1 and len(kwargs) < 1: if len(args) < 1 and len(kwargs) < 1:
...@@ -329,6 +329,10 @@ class ClassificationProject(object): ...@@ -329,6 +329,10 @@ class ClassificationProject(object):
self.planing_var, self.planing_bins, self.planing_range = planing_vars self.planing_var, self.planing_bins, self.planing_range = planing_vars
if self.planing_var is not None:
self.dataset_names_tree.append("planing_array")
self.dataset_names.append("planing_array")
self.s_train = None self.s_train = None
self.b_train = None self.b_train = None
self.s_test = None self.s_test = None
...@@ -413,11 +417,12 @@ class ClassificationProject(object): ...@@ -413,11 +417,12 @@ class ClassificationProject(object):
for filename, treename in self.bkg_trees: for filename, treename in self.bkg_trees:
bkg_chain.AddFile(filename, -1, treename) bkg_chain.AddFile(filename, -1, treename)
branches = self.branches branches = list(self.branches)
if self.planing_var is not None: if self.planing_var is not None:
branches.append(self.planing_var) branches.append(self.planing_var)
print(branches) # remove duplicates
branches = list(set(branches))
self.s_train = tree2array(signal_chain, self.s_train = tree2array(signal_chain,
branches=branches+[self.weight_expr]+self.identifiers, branches=branches+[self.weight_expr]+self.identifiers,
...@@ -851,7 +856,7 @@ class ClassificationProject(object): ...@@ -851,7 +856,7 @@ class ClassificationProject(object):
it becomes flat for both signal and background (information it becomes flat for both signal and background (information
effectively removed) effectively removed)
""" """
if self._w_train_plane is None and self.planing_array is not None: if self._w_train_plane is None and self.planing_var is not None:
self._w_train_plane = np.empty(len(self.x_train), dtype=float) self._w_train_plane = np.empty(len(self.x_train), dtype=float)
for class_label in [0, 1]: for class_label in [0, 1]:
ar = self.planing_array[self.y_train==class_label] ar = self.planing_array[self.y_train==class_label]
...@@ -864,9 +869,9 @@ class ClassificationProject(object): ...@@ -864,9 +869,9 @@ class ClassificationProject(object):
sfs = 1./hist sfs = 1./hist
sfs[np.isinf(sfs)] = 0 sfs[np.isinf(sfs)] = 0
sfs = np.concatenate([sfs, [0]]) # overflow is reweighted to 0 sfs = np.concatenate([sfs, [0]]) # overflow is reweighted to 0
bin_idx = np.digitize(ar, bins) bin_idx = np.digitize(ar, edges)
bin_inds -= 1 # different convention for digitize and histogram? bin_idx -= 1 # different convention for digitize and histogram?
self._w_train_plane[self.y_train==class_label] = sfs[bin_inds] self._w_train_plane[self.y_train==class_label] = sfs[bin_idx]
return self._w_train_plane return self._w_train_plane
...@@ -882,8 +887,6 @@ class ClassificationProject(object): ...@@ -882,8 +887,6 @@ class ClassificationProject(object):
w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
else: else:
w_train_tot = np.array(self.w_train) w_train_tot = np.array(self.w_train)
if self.normalize_weights:
w_train_tot /= np.mean(w_train_tot)
return w_train_tot return w_train_tot
...@@ -900,6 +903,8 @@ class ClassificationProject(object): ...@@ -900,6 +903,8 @@ class ClassificationProject(object):
self._w_train_tot = self.get_total_weight() self._w_train_tot = self.get_total_weight()
if self.w_train_plane is not None: if self.w_train_plane is not None:
self._w_train_tot *= self._w_train_plane self._w_train_tot *= self._w_train_plane
if self.normalize_weights:
self._w_train_tot /= np.mean(self._w_train_tot)
return self._w_train_tot return self._w_train_tot
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment