From 79807735b1966a5e3cb978cb191f037e4185e561 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Fri, 14 Dec 2018 10:51:21 +0100
Subject: [PATCH] fix WeightedRobustScaler

Pass the project's mask_value to WeightedRobustScaler.fit and exclude
entries equal to it (as well as NaN entries) from the per-column
weighted quantiles. The sample weights are filtered with the same
per-column mask so that values and weights stay aligned.

Plain RobustScaler.fit is used only when X contains no NaN, no
mask_value is given and no weights are given. When weights is None on
the weighted path, unit weights of length len(X) are used.

---
Review notes (not part of the commit message):

 * The fast-path guard was "mask_value is not None". That delegated to
   the unweighted RobustScaler.fit exactly when a mask value was
   supplied, so masked entries were never excluded on that path. It is
   now "mask_value is None".
 * The default weights were built with np.ones(len(self.X)). Nothing in
   this patch assigns self.X (X is the fit() argument), so that line
   raised AttributeError. It is now np.ones(len(X)).
 * Both fixes replace added lines one-for-one, so the hunk line counts
   and the diffstat are unchanged. The post-image blob id on the
   utils.py index line is still the one from the original commit.
 * Leading whitespace and the position of blank context lines were
   reconstructed from a whitespace-collapsed copy of this patch. If the
   context does not match the target files, apply with
   "git apply --ignore-whitespace" or regenerate the patch.

 toolkit.py |  1 +
 utils.py   | 24 ++++++++++++++++++------
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 0f804b7..8cd9c19 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -651,6 +651,7 @@ class ClassificationProject(object):
                 elif self.scaler_type == "WeightedRobustScaler":
                     self._scaler = WeightedRobustScaler()
                     scaler_fit_kwargs["weights"] = self.w_train_tot
+                    scaler_fit_kwargs["mask_value"] = self.mask_value
                 else:
                     raise ValueError("Scaler type {} unknown".format(self.scaler_type))
                 logger.info("Fitting {} to training data".format(self.scaler_type))
diff --git a/utils.py b/utils.py
index a353184..1af350e 100644
--- a/utils.py
+++ b/utils.py
@@ -197,14 +197,26 @@ def weighted_quantile(values, quantiles, sample_weight=None, values_sorted=False
 
 class WeightedRobustScaler(RobustScaler):
 
-    def fit(self, X, y=None, weights=None):
-        if not np.isnan(X).any():
+    def fit(self, X, y=None, weights=None, mask_value=None):
+        if not np.isnan(X).any() and mask_value is None and weights is None:
             # these checks don't work for nan values
-            super(WeightedRobustScaler, self).fit(X, y)
-        if weights is None:
-            return self
+            return super(WeightedRobustScaler, self).fit(X, y)
         else:
-            wqs = np.array([weighted_quantile(X[:,i][~np.isnan(X[:,i])], [0.25, 0.5, 0.75], sample_weight=weights) for i in range(X.shape[1])])
+            if weights is None:
+                weights = np.ones(len(X))
+            wqs = []
+            for i in range(X.shape[1]):
+                mask = ~np.isnan(X[:,i])
+                if mask_value is not None:
+                    mask &= (X[:,i] != mask_value)
+                wqs.append(
+                    weighted_quantile(
+                        X[:,i][mask],
+                        [0.25, 0.5, 0.75],
+                        sample_weight=weights[mask]
+                    )
+                )
+            wqs = np.array(wqs)
             self.center_ = wqs[:,1]
             self.scale_ = wqs[:,2]-wqs[:,0]
             self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
-- 
GitLab