diff --git a/toolkit.py b/toolkit.py index a8fe83334a688a1f11d5d4f44332670b60189464..b3c1146cec32e8f09ba4b7e0e1664559af43daf3 100755 --- a/toolkit.py +++ b/toolkit.py @@ -2042,14 +2042,18 @@ class ClassificationProjectDecorr(ClassificationProject): def load(self, *args, **kwargs): super(ClassificationProjectDecorr, self).load(*args, **kwargs) - bin_frac = 1./float(self.decorr_bins) + if not isinstance(self.decorr_bins, list): + bin_frac = 1./float(self.decorr_bins) + decorr_bins = np.arange(bin_frac, 1+bin_frac, bin_frac) + else: + decorr_bins = self.decorr_bins for idx, field_name in enumerate(self.target_fields): # adversary target is fit as multiclass problem with bin indices # (self.decorr_bins quantiles) as labels like in arXiv:1703.03507 self.decorr_binnings.append( weighted_quantile( self.y_train[self.l_train==0][:,idx+1], # bkg only - np.arange(bin_frac, 1+bin_frac, bin_frac), + decorr_bins, sample_weight=self.w_train[self.l_train==0] ) )