diff --git a/toolkit.py b/toolkit.py index 50cfa16fd9ff1f434429fd4f61431e484d4e1c22..bb47fb29baf02c8339d4f93a9f34e0c45e12afae 100755 --- a/toolkit.py +++ b/toolkit.py @@ -173,6 +173,10 @@ class ClassificationProject(object): :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets) + :param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background) + + :param normalize_weights: normalize the weights to mean 1 + """ @@ -237,7 +241,8 @@ class ClassificationProject(object): balance_dataset=False, loss='binary_crossentropy', mask_value=None, - apply_class_weight=True): + apply_class_weight=True, + normalize_weights=True): self.name = name self.signal_trees = signal_trees @@ -310,6 +315,7 @@ class ClassificationProject(object): self.loss = loss self.mask_value = mask_value self.apply_class_weight = apply_class_weight + self.normalize_weights = normalize_weights self.s_train = None self.b_train = None @@ -810,7 +816,8 @@ class ClassificationProject(object): self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] else: self._w_train_tot = np.array(self.w_train) - self._w_train_tot /= np.mean(self._w_train_tot) + if self.normalize_weights: + self._w_train_tot /= np.mean(self._w_train_tot) return self._w_train_tot