From 79807735b1966a5e3cb978cb191f037e4185e561 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Fri, 14 Dec 2018 10:51:21 +0100
Subject: [PATCH] fix WeightedRobustScaler

---
 toolkit.py |  1 +
 utils.py   | 24 ++++++++++++++++++------
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 0f804b7..8cd9c19 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -651,6 +651,7 @@ class ClassificationProject(object):
                 elif self.scaler_type == "WeightedRobustScaler":
                     self._scaler = WeightedRobustScaler()
                     scaler_fit_kwargs["weights"] = self.w_train_tot
+                    scaler_fit_kwargs["mask_value"] = self.mask_value
                 else:
                     raise ValueError("Scaler type {} unknown".format(self.scaler_type))
                 logger.info("Fitting {} to training data".format(self.scaler_type))
diff --git a/utils.py b/utils.py
index a353184..1af350e 100644
--- a/utils.py
+++ b/utils.py
@@ -197,14 +197,26 @@ def weighted_quantile(values, quantiles, sample_weight=None, values_sorted=False
 
 class WeightedRobustScaler(RobustScaler):
 
-    def fit(self, X, y=None, weights=None):
-        if not np.isnan(X).any():
+    def fit(self, X, y=None, weights=None, mask_value=None):
+        if not np.isnan(X).any() and mask_value is not None and weights is None:
             # these checks don't work for nan values
-            super(WeightedRobustScaler, self).fit(X, y)
-        if weights is None:
-            return self
+            return super(WeightedRobustScaler, self).fit(X, y)
         else:
-            wqs = np.array([weighted_quantile(X[:,i][~np.isnan(X[:,i])], [0.25, 0.5, 0.75], sample_weight=weights) for i in range(X.shape[1])])
+            if weights is None:
+                weights = np.ones(len(self.X))
+            wqs = []
+            for i in range(X.shape[1]):
+                mask = ~np.isnan(X[:,i])
+                if mask_value is not None:
+                    mask &= (X[:,i] != mask_value)
+                wqs.append(
+                    weighted_quantile(
+                        X[:,i][mask],
+                        [0.25, 0.5, 0.75],
+                        sample_weight=weights[mask]
+                    )
+                )
+            wqs = np.array(wqs)
             self.center_ = wqs[:,1]
             self.scale_ = wqs[:,2]-wqs[:,0]
             self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
-- 
GitLab