Skip to content
Snippets Groups Projects
Commit 0bd404e7 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

adding class_weight_target property

parent 1e44f56c
No related branches found
No related tags found
No related merge requests found
...@@ -2038,6 +2038,7 @@ class ClassificationProjectDecorr(ClassificationProject): ...@@ -2038,6 +2038,7 @@ class ClassificationProjectDecorr(ClassificationProject):
self._class_output = None self._class_output = None
self._adv_outputs = None self._adv_outputs = None
self._model_adv = None self._model_adv = None
self._class_weight_target = None
def load(self, *args, **kwargs): def load(self, *args, **kwargs):
...@@ -2059,6 +2060,26 @@ class ClassificationProjectDecorr(ClassificationProject): ...@@ -2059,6 +2060,26 @@ class ClassificationProjectDecorr(ClassificationProject):
) )
@property
def class_weight_target(self):
if self._class_weight_target is None:
self._class_weight_target = []
for var_i, binning in enumerate(self.decorr_binnings, 1):
sumw = self.w_train[self.l_train==0].sum()
class_weight = [
sumw/(
len(binning)
* self.w_train[
(np.argmax(self.get_output_list(self.y_train)[var_i], axis=1) == label)
& (self.l_train == 0)
].sum()
)
for label in range(len(binning))
]
self._class_weight_target.append(class_weight)
return self._class_weight_target
def get_output_list(self, y): def get_output_list(self, y):
out_list = super(ClassificationProjectDecorr, self).get_output_list(y) out_list = super(ClassificationProjectDecorr, self).get_output_list(y)
for i, (out, binning) in enumerate( for i, (out, binning) in enumerate(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment