diff --git a/utils.py b/utils.py index 48f58f60b909b912d8b47038fd732c1ce339a05c..a353184806feebc438ec1f19a5c831d7b0ecbb5f 100644 --- a/utils.py +++ b/utils.py @@ -45,14 +45,19 @@ def create_random_event(ranges, mask_probs=None, mask_value=None): return random_event -def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None): +def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None, max_evts=None): "Get ranges for plotting or random event generation based on quantiles" ranges = [] mask_probs = [] + if max_evts is not None: + rnd_idx = np.random.permutation(np.arange(len(x))) + rnd_idx = rnd_idx[:max_evts] for var_index in range(x.shape[1]): if (filter_index is not None) and (var_index != filter_index): continue x_var = x[:,var_index] + if max_evts is not None: + x_var = x_var[rnd_idx] not_masked = np.where(x_var != mask_value)[0] masked = np.where(x_var == mask_value)[0] ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked]))