Skip to content
Snippets Groups Projects
Commit 20f895ff authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

trying to implement data planing in 1D

parent 49243434
No related branches found
No related tags found
No related merge requests found
...@@ -179,14 +179,18 @@ class ClassificationProject(object): ...@@ -179,14 +179,18 @@ class ClassificationProject(object):
:param normalize_weights: normalize the weights to mean 1 :param normalize_weights: normalize the weights to mean 1
:param planing_vars: variable in which a binwise reweighting
should be performed, such that its distribution becomes flat
("Data planing"). Pass a tuple of (expr, bins, range)
""" """
# Datasets that are stored to (and dynamically loaded from) hdf5 # Datasets that are stored to (and dynamically loaded from) hdf5
dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test"] dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test", "planing_array"]
# Datasets that are retrieved from ROOT trees the first time # Datasets that are retrieved from ROOT trees the first time
dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"] dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "planing_array"]
def __init__(self, name, *args, **kwargs): def __init__(self, name, *args, **kwargs):
if len(args) < 1 and len(kwargs) < 1: if len(args) < 1 and len(kwargs) < 1:
...@@ -245,7 +249,9 @@ class ClassificationProject(object): ...@@ -245,7 +249,9 @@ class ClassificationProject(object):
loss='binary_crossentropy', loss='binary_crossentropy',
mask_value=None, mask_value=None,
apply_class_weight=True, apply_class_weight=True,
normalize_weights=True): normalize_weights=True,
planing_vars=(None, None, None),
):
self.name = name self.name = name
self.signal_trees = signal_trees self.signal_trees = signal_trees
...@@ -321,6 +327,8 @@ class ClassificationProject(object): ...@@ -321,6 +327,8 @@ class ClassificationProject(object):
self.apply_class_weight = apply_class_weight self.apply_class_weight = apply_class_weight
self.normalize_weights = normalize_weights self.normalize_weights = normalize_weights
self.planing_var, self.planing_bins, self.planing_range = planing_vars
self.s_train = None self.s_train = None
self.b_train = None self.b_train = None
self.s_test = None self.s_test = None
...@@ -334,10 +342,14 @@ class ClassificationProject(object): ...@@ -334,10 +342,14 @@ class ClassificationProject(object):
self._w_test = None self._w_test = None
self._scores_train = None self._scores_train = None
self._scores_test = None self._scores_test = None
self._planing_array = None
# class weighted training data (divided by mean) # class weighted training data (divided by mean)
self._w_train_tot = None self._w_train_tot = None
# planing weights in case requested
self._w_train_plane = None
self._s_eventlist_train = None self._s_eventlist_train = None
self._b_eventlist_train = None self._b_eventlist_train = None
...@@ -400,20 +412,27 @@ class ClassificationProject(object): ...@@ -400,20 +412,27 @@ class ClassificationProject(object):
signal_chain.AddFile(filename, -1, treename) signal_chain.AddFile(filename, -1, treename)
for filename, treename in self.bkg_trees: for filename, treename in self.bkg_trees:
bkg_chain.AddFile(filename, -1, treename) bkg_chain.AddFile(filename, -1, treename)
# Build the branch list for tree2array. Copy self.branches first:
# binding the attribute directly and then appending would mutate
# self.branches in place, growing it by one planing var per call.
branches = list(self.branches)
if self.planing_var is not None:
    branches.append(self.planing_var)
self.s_train = tree2array(signal_chain, self.s_train = tree2array(signal_chain,
branches=self.branches+[self.weight_expr]+self.identifiers, branches=branches+[self.weight_expr]+self.identifiers,
selection=self.selection, selection=self.selection,
start=0, step=self.step_signal, stop=self.stop_train) start=0, step=self.step_signal, stop=self.stop_train)
self.b_train = tree2array(bkg_chain, self.b_train = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr]+self.identifiers, branches=branches+[self.weight_expr]+self.identifiers,
selection=self.selection, selection=self.selection,
start=0, step=self.step_bkg, stop=self.stop_train) start=0, step=self.step_bkg, stop=self.stop_train)
self.s_test = tree2array(signal_chain, self.s_test = tree2array(signal_chain,
branches=self.branches+[self.weight_expr], branches=branches+[self.weight_expr],
selection=self.selection, selection=self.selection,
start=1, step=self.step_signal, stop=self.stop_test) start=1, step=self.step_signal, stop=self.stop_test)
self.b_test = tree2array(bkg_chain, self.b_test = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr], branches=branches+[self.weight_expr],
selection=self.selection, selection=self.selection,
start=1, step=self.step_bkg, stop=self.stop_test) start=1, step=self.step_bkg, stop=self.stop_test)
...@@ -426,6 +445,10 @@ class ClassificationProject(object): ...@@ -426,6 +445,10 @@ class ClassificationProject(object):
self.b_eventlist_train = self.b_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers]) self.b_eventlist_train = self.b_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
self._dump_training_list() self._dump_training_list()
# Store the raw values of the planing variable for the training set so
# the flattening weights can be computed later (see w_train_plane).
# NOTE(review): concatenated [signal, background] — assumed to stay
# aligned with the y_train class-label ordering; confirm against the
# code that builds y_train.
if self.planing_var is not None:
    self.planing_array = np.concatenate([self.s_train[self.planing_var], self.b_train[self.planing_var]])
# now we don't need the identifiers anymore # now we don't need the identifiers anymore
self.s_train = self.s_train[self.fields+[self.weight_expr]] self.s_train = self.s_train[self.fields+[self.weight_expr]]
self.b_train = self.b_train[self.fields+[self.weight_expr]] self.b_train = self.b_train[self.fields+[self.weight_expr]]
...@@ -821,6 +844,49 @@ class ClassificationProject(object): ...@@ -821,6 +844,49 @@ class ClassificationProject(object):
self.data_shuffled = True self.data_shuffled = True
@property
def w_train_plane(self):
    """
    Per-event weights that flatten ("plane") the distribution of the
    configured planing variable, separately for each class (signal and
    background), so the variable carries no discriminating information.

    Lazily computed and cached in self._w_train_plane; returns None
    when no planing variable was configured (self.planing_array is None).
    Events falling outside the histogram range are weighted to 0.
    """
    if self._w_train_plane is None and self.planing_array is not None:
        self._w_train_plane = np.empty(len(self.x_train), dtype=float)
        for class_label in [0, 1]:
            mask = self.y_train == class_label
            ar = self.planing_array[mask]
            hist, edges = np.histogram(
                ar,
                bins=self.planing_bins,
                range=self.planing_range,
                weights=self.get_total_weight()[mask],
            )
            # scale factor per bin = inverse of the weighted bin content;
            # empty bins produce inf, which we map to weight 0 below
            with np.errstate(divide="ignore"):
                sfs = 1. / hist
            sfs[np.isinf(sfs)] = 0
            # extra trailing entry: overflow (index nbins) is reweighted
            # to 0, and underflow (index -1 after the shift) also hits it
            sfs = np.concatenate([sfs, [0]])
            # digitize returns 1-based indices w.r.t. edges, while the
            # histogram bins are 0-based — shift down by one
            bin_inds = np.digitize(ar, edges) - 1
            self._w_train_plane[mask] = sfs[bin_inds]
    return self._w_train_plane
def get_total_weight(self):
    """
    Return the per-event training weight: the sample weight multiplied
    by the class weight (when apply_class_weight is set), optionally
    divided by its mean so the weights average to 1.

    Raises ValueError when the training data has not been loaded yet.
    """
    # pick the class-weight scheme matching the dataset-balancing mode
    if self.balance_dataset:
        class_weight = self.balanced_class_weight
    else:
        class_weight = self.class_weight
    if not self.data_loaded:
        raise ValueError("Data not loaded! can't calculate total weight")
    if self.apply_class_weight:
        # index the two class weights by each event's integer label
        total = self.w_train * np.array(class_weight)[self.y_train.astype(int)]
    else:
        # copy so normalization below never touches self.w_train
        total = np.array(self.w_train)
    if self.normalize_weights:
        total = total / np.mean(total)
    return total
@property @property
def w_train_tot(self): def w_train_tot(self):
"(sample weight * class weight), divided by mean" "(sample weight * class weight), divided by mean"
...@@ -831,12 +897,9 @@ class ClassificationProject(object): ...@@ -831,12 +897,9 @@ class ClassificationProject(object):
if not self.data_loaded: if not self.data_loaded:
raise ValueError("Data not loaded! can't calculate total weight") raise ValueError("Data not loaded! can't calculate total weight")
if self._w_train_tot is None: if self._w_train_tot is None:
if self.apply_class_weight: self._w_train_tot = self.get_total_weight()
self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] if self.w_train_plane is not None:
else: self._w_train_tot *= self._w_train_plane
self._w_train_tot = np.array(self.w_train)
if self.normalize_weights:
self._w_train_tot /= np.mean(self._w_train_tot)
return self._w_train_tot return self._w_train_tot
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment