diff --git a/toolkit.py b/toolkit.py index 61fd512d71e36d9f8469d792052d435c2f509423..947968c6a6d250db87d27511f125cad583198915 100755 --- a/toolkit.py +++ b/toolkit.py @@ -2062,11 +2062,12 @@ class ClassificationProjectDecorr(ClassificationProject): @property def class_weight_target(self): + "to weight the adversarial target to have equal sum of weights per bin" if self._class_weight_target is None: self._class_weight_target = [] for var_i, binning in enumerate(self.decorr_binnings, 1): sumw = self.w_train[self.l_train==0].sum() - class_weight = [ + class_weight = np.array([ sumw/( len(binning) * self.w_train[ @@ -2075,7 +2076,7 @@ class ClassificationProjectDecorr(ClassificationProject): ].sum() ) for label in range(len(binning)) - ] + ]) self._class_weight_target.append(class_weight) return self._class_weight_target @@ -2099,6 +2100,11 @@ class ClassificationProjectDecorr(ClassificationProject): w_list[i] = np.array(w_list[i]) # set signal weights to 0 for decorr target w_list[i][y[:,0]==1] = 0. + bin_labels = np.argmax( + self.get_output_list(y[y[:,0]==0])[1], + axis=1 + ) + w_list[i][y[:,0]==0] *= self.class_weight_target[i-1][bin_labels] return w_list