From 0bd404e7dade3ce12ea0336d021fbee29270f85c Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Fri, 30 Nov 2018 12:08:28 +0100 Subject: [PATCH] adding class_weight_target property --- toolkit.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/toolkit.py b/toolkit.py index b3c1146..61fd512 100755 --- a/toolkit.py +++ b/toolkit.py @@ -2038,6 +2038,7 @@ class ClassificationProjectDecorr(ClassificationProject): self._class_output = None self._adv_outputs = None self._model_adv = None + self._class_weight_target = None def load(self, *args, **kwargs): @@ -2059,6 +2060,26 @@ class ClassificationProjectDecorr(ClassificationProject): ) + @property + def class_weight_target(self): + if self._class_weight_target is None: + self._class_weight_target = [] + for var_i, binning in enumerate(self.decorr_binnings, 1): + sumw = self.w_train[self.l_train==0].sum() + class_weight = [ + sumw/( + len(binning) + * self.w_train[ + (np.argmax(self.get_output_list(self.y_train)[var_i], axis=1) == label) + & (self.l_train == 0) + ].sum() + ) + for label in range(len(binning)) + ] + self._class_weight_target.append(class_weight) + return self._class_weight_target + + def get_output_list(self, y): out_list = super(ClassificationProjectDecorr, self).get_output_list(y) for i, (out, binning) in enumerate( -- GitLab