From 96759ea306a87d31483f40c7576690517057586a Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 21 Aug 2018 18:01:01 +0200 Subject: [PATCH] fixing range in plot_NN_2D --- scripts/plot_NN_2D.py | 15 ++++++--------- utils.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py index 0a0be89..1555095 100755 --- a/scripts/plot_NN_2D.py +++ b/scripts/plot_NN_2D.py @@ -20,7 +20,12 @@ from KerasROOTClassification.plotting import ( plot_cond_avg_actmax_2D, plot_NN_vs_var_2D_all, ) -from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events, weighted_quantile +from KerasROOTClassification.utils import ( + get_single_neuron_function, + get_max_activation_events, + weighted_quantile, + get_ranges +) parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron') parser.add_argument("project_dir") @@ -89,14 +94,6 @@ except NameError: # percentilesx = weighted_quantile(varx_test[x_not_masked], [0.01, 0.99], sample_weight=total_weights[x_not_masked]) # percentilesy = weighted_quantile(vary_test[y_not_masked], [0.01, 0.99], sample_weight=total_weights[y_not_masked]) -def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None): - ranges = [] - for var_index in range(x.shape[1]): - x_var = x[:,var_index] - not_masked = np.where(x_var != mask_value)[0] - ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked])) - return ranges - percentilesx = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=varx_index)[0] percentilesy = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=vary_index)[0] diff --git a/utils.py b/utils.py index a71f251..f3323c4 100644 --- a/utils.py +++ b/utils.py @@ -39,6 +39,18 @@ def create_random_event(ranges): return random_event +def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None): + "Get ranges for plotting or random event generation based on quantiles" + ranges = [] + for var_index in range(x.shape[1]): + if (filter_index is not None) and (var_index != filter_index): + continue + x_var = x[:,var_index] + not_masked = np.where(x_var != mask_value)[0] + ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked])) + return ranges + + def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[], input_transform=None, input_inverse_transform=None): if input_transform is not None: -- GitLab