diff --git a/utils.py b/utils.py index adf213f7edb71086fa8e9148a77062565f098f0e..1da306ca79f3b350d9e12708c942254ddebea067 100644 --- a/utils.py +++ b/utils.py @@ -6,6 +6,7 @@ import numpy as np import keras.backend as K from sklearn.preprocessing import RobustScaler +from sklearn.preprocessing.data import _handle_zeros_in_scale from meme import cache @@ -140,5 +141,6 @@ class WeightedRobustScaler(RobustScaler): wqs = np.array([weighted_quantile(X[:,i], [0.25, 0.5, 0.75], sample_weight=weights) for i in range(X.shape[1])]) self.center_ = wqs[:,1] self.scale_ = wqs[:,2]-wqs[:,0] + self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False) return self