Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • Eric.Schanet/KerasROOTClassification
  • Nikolai.Hartmann/KerasROOTClassification
2 results
Show changes
Commits on Source (2)
......@@ -651,6 +651,7 @@ class ClassificationProject(object):
elif self.scaler_type == "WeightedRobustScaler":
self._scaler = WeightedRobustScaler()
scaler_fit_kwargs["weights"] = self.w_train_tot
scaler_fit_kwargs["mask_value"] = self.mask_value
else:
raise ValueError("Scaler type {} unknown".format(self.scaler_type))
logger.info("Fitting {} to training data".format(self.scaler_type))
......@@ -1334,6 +1335,10 @@ class ClassificationProject(object):
break
if self.target_fields:
y = y[0]
try:
x = self.get_input_flat(x)
except NameError:
pass
bkg_batch = x[:,var_index][y==0]
sig_batch = x[:,var_index][y==1]
bkg_weights_batch = w[y==0]
......
......@@ -197,14 +197,26 @@ def weighted_quantile(values, quantiles, sample_weight=None, values_sorted=False
class WeightedRobustScaler(RobustScaler):
def fit(self, X, y=None, weights=None):
if not np.isnan(X).any():
def fit(self, X, y=None, weights=None, mask_value=None):
if not np.isnan(X).any() and mask_value is not None and weights is None:
# these checks don't work for nan values
super(WeightedRobustScaler, self).fit(X, y)
if weights is None:
return self
return super(WeightedRobustScaler, self).fit(X, y)
else:
wqs = np.array([weighted_quantile(X[:,i][~np.isnan(X[:,i])], [0.25, 0.5, 0.75], sample_weight=weights) for i in range(X.shape[1])])
if weights is None:
weights = np.ones(len(self.X))
wqs = []
for i in range(X.shape[1]):
mask = ~np.isnan(X[:,i])
if mask_value is not None:
mask &= (X[:,i] != mask_value)
wqs.append(
weighted_quantile(
X[:,i][mask],
[0.25, 0.5, 0.75],
sample_weight=weights[mask]
)
)
wqs = np.array(wqs)
self.center_ = wqs[:,1]
self.scale_ = wqs[:,2]-wqs[:,0]
self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
......