From 0a1870dbd101496e06bd585293c4a18b7908d1f1 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 28 Aug 2018 08:59:07 +0200 Subject: [PATCH] option to turn off class weight --- toolkit.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/toolkit.py b/toolkit.py index 5280d07..09fcacd 100755 --- a/toolkit.py +++ b/toolkit.py @@ -236,7 +236,8 @@ class ClassificationProject(object): random_seed=1234, balance_dataset=False, loss='binary_crossentropy', - mask_value=None): + mask_value=None, + apply_class_weight=True): self.name = name self.signal_trees = signal_trees @@ -308,6 +309,7 @@ class ClassificationProject(object): self.balance_dataset = balance_dataset self.loss = loss self.mask_value = mask_value + self.apply_class_weight = apply_class_weight self.s_train = None self.b_train = None @@ -804,7 +806,10 @@ class ClassificationProject(object): else: class_weight = self.balanced_class_weight if self._w_train_tot is None: - self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] + if self.apply_class_weight: + self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] + else: + self._w_train_tot = np.array(self.w_train) self._w_train_tot /= np.mean(self._w_train_tot) return self._w_train_tot -- GitLab