From 80673f0e90bddcfd366bb50f19459e1a66b8f427 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Wed, 28 Nov 2018 10:48:59 +0100
Subject: [PATCH] starting adversarial setup for decorrelation

---
 toolkit.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/toolkit.py b/toolkit.py
index d673121..89e3e89 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"]
+__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN", "ClassificationProjectDecorr"]
 
 from sys import version_info
 
@@ -72,6 +72,8 @@ def load_from_dir(path):
         project_type = info["project_type"]
         if project_type == "ClassificationProjectRNN":
             return ClassificationProjectRNN(path)
+        elif project_type == "ClassificationProjectDecorr":
+            return ClassificationProjectDecorr(path)
     except (KeyError, IOError):
         pass
     return ClassificationProject(path)
@@ -2014,6 +2016,56 @@ class ClassificationProjectRNN(ClassificationProject):
         return self.predict(self.get_input_list(x_eval), mode=mode)
 
 
+class ClassificationProjectDecorr(ClassificationProject):
+    """Project whose target fields are binned as adversary labels.
+
+    The target fields are sorted into `decorr_bins` quantile bins of the
+    background training sample and one-hot encoded (arXiv:1703.03507).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(ClassificationProjectDecorr, self).__init__(*args, **kwargs)
+        # one array of upper bin edges per target field, filled by load()
+        self.decorr_binnings = []
+        self.decorr_bins = 10
+        self._write_info("project_type", "ClassificationProjectDecorr")
+
+
+    def load(self, *args, **kwargs):
+        """Load the data and derive the quantile binning per target field."""
+        super(ClassificationProjectDecorr, self).load(*args, **kwargs)
+        # start from scratch, so repeated calls don't pile up binnings
+        self.decorr_binnings = []
+        # unlike float-step arange, linspace yields exactly decorr_bins edges
+        quantiles = np.linspace(1./self.decorr_bins, 1., self.decorr_bins)
+        is_bkg = self.l_train == 0
+        for idx in range(len(self.target_fields)):
+            # adversary target is fit as multiclass problem with bin indices
+            # (self.decorr_bins quantiles) as labels like in arXiv:1703.03507
+            self.decorr_binnings.append(
+                weighted_quantile(
+                    self.y_train[is_bkg][:,idx], # bkg only
+                    quantiles,
+                    sample_weight=self.w_train[is_bkg]
+                )
+            )
+
+
+    def get_output_list(self, y):
+        """Return the outputs with all but the first one-hot bin encoded."""
+        out_list = super(ClassificationProjectDecorr, self).get_output_list(y)
+        # start=1: out_list[0] is not binned and must not be overwritten
+        for i, (out, binning) in enumerate(
+                zip(out_list[1:], self.decorr_binnings), start=1
+        ):
+            bin_idx = np.digitize(out, binning)
+            # include overflow into last bin
+            bin_idx[bin_idx==len(binning)] = len(binning)-1
+            # fixed class count: encoding width must not depend on `y`
+            out_list[i] = keras.utils.to_categorical(bin_idx, len(binning))
+        return out_list
+
+
 if __name__ == "__main__":
     logging.basicConfig()
 
-- 
GitLab