Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • Eric.Schanet/KerasROOTClassification
  • Nikolai.Hartmann/KerasROOTClassification
2 results
Show changes
Commits on Source (2)
#!/usr/bin/env python
__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"]
__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN", "ClassificationProjectDecorr"]
from sys import version_info
......@@ -72,6 +72,8 @@ def load_from_dir(path):
project_type = info["project_type"]
if project_type == "ClassificationProjectRNN":
return ClassificationProjectRNN(path)
elif project_type == "ClassificationProjectDecorr":
return ClassificationProjectDecorr(path)
except (KeyError, IOError):
pass
return ClassificationProject(path)
......@@ -934,7 +936,7 @@ class ClassificationProject(object):
x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_val_input = self.get_input_list(self.transform(x_val))
y_val_output = self.get_output_list(self.transform_target(y_val))
w_val_list = self.get_weight_list(w_val)
w_val_list = self.get_weight_list(w_val, y_val)
return x_val_input, y_val_output, w_val_list
......@@ -945,7 +947,7 @@ class ClassificationProject(object):
x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_train_input = self.get_input_list(self.transform(x_train))
y_train_output = self.get_output_list(self.transform_target(y_train))
w_train_list = self.get_weight_list(w_train)
w_train_list = self.get_weight_list(w_train, y_train)
return x_train_input, y_train_output, w_train_list
......@@ -989,7 +991,7 @@ class ClassificationProject(object):
return np.hsplit(y, len(self.target_fields)+1)
def get_weight_list(self, w):
def get_weight_list(self, w, y=None):
"Repeat weight n times for regression targets"
if not self.target_fields:
return w
......@@ -1016,7 +1018,7 @@ class ClassificationProject(object):
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
x_input = self.get_input_list(self.transform(x_batch))
y_output = self.get_output_list(self.transform_target(y_batch))
w_list = self.get_weight_list(w_batch)
w_list = self.get_weight_list(w_batch, y_batch)
yield (x_input, y_output, w_list)
......@@ -2014,6 +2016,114 @@ class ClassificationProjectRNN(ClassificationProject):
return self.predict(self.get_input_list(x_eval), mode=mode)
class ClassificationProjectDecorr(ClassificationProject):
    """
    Classification project with an additional adversary that is trained on
    the regression targets (``self.target_fields``).

    The adversary target is fit as a multiclass problem: each target field
    is binned into ``self.decorr_bins`` weighted quantiles of the background
    training sample and the bin index is used as the label (like in
    arXiv:1703.03507).
    """

    def __init__(self, *args, **kwargs):
        """
        Forward all arguments to the base class, set the defaults for the
        decorrelation binning and record the project type in the project
        info.
        """
        super(ClassificationProjectDecorr, self).__init__(*args, **kwargs)
        # one array of quantile bin edges per target field, filled by load()
        self.decorr_binnings = []
        # number of quantile bins used for each adversary target
        self.decorr_bins = 10
        self._write_info("project_type", "ClassificationProjectDecorr")
        # lazily created layer lists, see class_layers / adv_layers
        self._class_layers = None
        self._adv_hidden_layers = None
        self._adv_target_layers = None

    def load(self, *args, **kwargs):
        """
        Load the data via the base class and (re)compute the quantile bin
        edges of every target field from the background training events.
        """
        super(ClassificationProjectDecorr, self).load(*args, **kwargs)
        n_bins = int(self.decorr_bins)
        # Integer range divided by the bin count instead of a float-step
        # arange: always yields exactly n_bins fractions ending at 1.
        quantiles = np.arange(1, n_bins + 1) / float(n_bins)
        is_bkg = self.l_train == 0
        bkg_targets = self.y_train[is_bkg]
        bkg_weights = self.w_train[is_bkg]
        # Rebuild the list from scratch so that calling load() more than
        # once does not accumulate duplicate binnings.
        self.decorr_binnings = [
            weighted_quantile(
                bkg_targets[:, idx + 1],  # column 0 is the class label
                quantiles,
                sample_weight=bkg_weights
            )
            for idx, _ in enumerate(self.target_fields)
        ]

    def get_output_list(self, y):
        """
        Return the output list of the base class with every adversary
        target replaced by the one-hot encoded index of its quantile bin.
        """
        out_list = super(ClassificationProjectDecorr, self).get_output_list(y)
        for i, (out, binning) in enumerate(
                zip(out_list[1:], self.decorr_binnings), 1
        ):
            n_classes = len(binning)
            # include overflow into last bin
            bin_idx = np.minimum(np.digitize(out, binning), n_classes - 1)
            # Fix the number of classes explicitly - otherwise the width of
            # the encoding depends on the largest bin index that happens to
            # be present in this particular batch.
            out_list[i] = keras.utils.to_categorical(
                bin_idx, num_classes=n_classes
            )
        return out_list

    def get_weight_list(self, w, y=None):
        """
        Return the weight list of the base class with the signal weights
        set to 0 for all adversary (decorrelation) targets.

        ``y`` holds the untransformed targets, with the class label in the
        first column. It is required as soon as target fields are
        configured (a ValueError is raised otherwise). Without target
        fields the result of the base class is returned unchanged.
        """
        w_list = super(ClassificationProjectDecorr, self).get_weight_list(w)
        if not self.target_fields:
            # base class hands back the plain weights - nothing to mask
            return w_list
        if y is None:
            raise ValueError(
                "y is required to mask the signal weights "
                "for the decorrelation targets"
            )
        is_signal = y[:, 0] == 1
        # Work on copies: the entries of the base class list might all be
        # references to the array that was passed in.
        masked_list = [np.array(w_list[0])]
        for adv_weights in w_list[1:]:
            adv_weights = np.array(adv_weights)
            # set signal weights to 0 for decorr target
            adv_weights[is_signal] = 0.
            masked_list.append(adv_weights)
        return masked_list

    @property
    def class_layers(self):
        """
        Layers for the classification model

        This should be generalised to avoid code duplication with the model
        functions of the base classes
        """
        if self._class_layers is None:
            layers = [Input((len(self.fields),))]
            if self.dropout_input is not None:
                layers.append(Dropout(rate=self.dropout_input))
            for node_count, dropout_fraction, use_bias in zip(
                    self.nodes,
                    self.dropout,
                    self.use_bias,
            ):
                layers.append(
                    Dense(
                        node_count,
                        activation=self.activation_function,
                        use_bias=use_bias
                    )
                )
                if (dropout_fraction is not None) and (dropout_fraction > 0):
                    layers.append(Dropout(rate=dropout_fraction))
            layers.append(Dense(1, activation=self.activation_function_output))
            # cache only after the list is complete, so a failure above
            # does not leave a partially built list behind
            self._class_layers = layers
        return self._class_layers

    @property
    def adv_layers(self):
        """
        Layers for the adversary

        One hidden layer followed by one softmax layer per entry of
        ``self.decorr_binnings``. The layers are created on first access
        and cached afterwards.
        """
        if self._adv_hidden_layers is None:
            hidden_layers = [Dense(128, activation="tanh")]
            target_layers = [
                Dense(len(binning), activation="softmax")
                for binning in self.decorr_binnings
            ]
            self._adv_hidden_layers = hidden_layers
            self._adv_target_layers = target_layers
        return self._adv_hidden_layers + self._adv_target_layers

    @property
    def class_input(self):
        """
        Placeholder - not implemented yet, currently evaluates to None
        """
        return None

    @property
    def model(self):
        """
        Placeholder - not implemented yet, currently evaluates to None
        """
        return None
if __name__ == "__main__":
logging.basicConfig()
......