diff --git a/toolkit.py b/toolkit.py index 5280d07aedec891605415b0283fd12882bbc1017..09fcacd840250896811b076ee5fdce32f3cb05fc 100755 --- a/toolkit.py +++ b/toolkit.py @@ -236,7 +236,8 @@ class ClassificationProject(object): random_seed=1234, balance_dataset=False, loss='binary_crossentropy', - mask_value=None): + mask_value=None, + apply_class_weight=True): self.name = name self.signal_trees = signal_trees @@ -308,6 +309,7 @@ class ClassificationProject(object): self.balance_dataset = balance_dataset self.loss = loss self.mask_value = mask_value + self.apply_class_weight = apply_class_weight self.s_train = None self.b_train = None @@ -804,7 +806,10 @@ class ClassificationProject(object): else: class_weight = self.balanced_class_weight if self._w_train_tot is None: - self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] + if self.apply_class_weight: + self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] + else: + self._w_train_tot = np.array(self.w_train) self._w_train_tot /= np.mean(self._w_train_tot) return self._w_train_tot