From 53ba3604e33b7b94aa603833d0e4786bec62616a Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Fri, 30 Nov 2018 15:09:20 +0100 Subject: [PATCH] apply class weights for adversarial targets --- toolkit.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/toolkit.py b/toolkit.py index 61fd512..947968c 100755 --- a/toolkit.py +++ b/toolkit.py @@ -2062,11 +2062,12 @@ class ClassificationProjectDecorr(ClassificationProject): @property def class_weight_target(self): + "to weight the adversarial target to have equal sum of weights per bin" if self._class_weight_target is None: self._class_weight_target = [] for var_i, binning in enumerate(self.decorr_binnings, 1): sumw = self.w_train[self.l_train==0].sum() - class_weight = [ + class_weight = np.array([ sumw/( len(binning) * self.w_train[ @@ -2075,7 +2076,7 @@ class ClassificationProjectDecorr(ClassificationProject): ].sum() ) for label in range(len(binning)) - ] + ]) self._class_weight_target.append(class_weight) return self._class_weight_target @@ -2099,6 +2100,11 @@ class ClassificationProjectDecorr(ClassificationProject): w_list[i] = np.array(w_list[i]) # set signal weights to 0 for decorr target w_list[i][y[:,0]==1] = 0. + bin_labels = np.argmax( + self.get_output_list(y[y[:,0]==0])[1], + axis=1 + ) + w_list[i][y[:,0]==0] *= self.class_weight_target[i-1][bin_labels] return w_list -- GitLab