Skip to content
Snippets Groups Projects
Commit 96759ea3 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

fixing range in plot_NN_2D

parent b8d18de9
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,12 @@ from KerasROOTClassification.plotting import ( ...@@ -20,7 +20,12 @@ from KerasROOTClassification.plotting import (
plot_cond_avg_actmax_2D, plot_cond_avg_actmax_2D,
plot_NN_vs_var_2D_all, plot_NN_vs_var_2D_all,
) )
from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events, weighted_quantile from KerasROOTClassification.utils import (
get_single_neuron_function,
get_max_activation_events,
weighted_quantile,
get_ranges
)
parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron') parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron')
parser.add_argument("project_dir") parser.add_argument("project_dir")
...@@ -89,14 +94,6 @@ except NameError: ...@@ -89,14 +94,6 @@ except NameError:
# percentilesx = weighted_quantile(varx_test[x_not_masked], [0.01, 0.99], sample_weight=total_weights[x_not_masked]) # percentilesx = weighted_quantile(varx_test[x_not_masked], [0.01, 0.99], sample_weight=total_weights[x_not_masked])
# percentilesy = weighted_quantile(vary_test[y_not_masked], [0.01, 0.99], sample_weight=total_weights[y_not_masked]) # percentilesy = weighted_quantile(vary_test[y_not_masked], [0.01, 0.99], sample_weight=total_weights[y_not_masked])
def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None):
ranges = []
for var_index in range(x.shape[1]):
x_var = x[:,var_index]
not_masked = np.where(x_var != mask_value)[0]
ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked]))
return ranges
percentilesx = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=varx_index)[0] percentilesx = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=varx_index)[0]
percentilesy = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=vary_index)[0] percentilesy = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=vary_index)[0]
......
...@@ -39,6 +39,18 @@ def create_random_event(ranges): ...@@ -39,6 +39,18 @@ def create_random_event(ranges):
return random_event return random_event
def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None):
"Get ranges for plotting or random event generation based on quantiles"
ranges = []
for var_index in range(x.shape[1]):
if (filter_index is not None) and (var_index != filter_index):
continue
x_var = x[:,var_index]
not_masked = np.where(x_var != mask_value)[0]
ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked]))
return ranges
def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[], def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[],
input_transform=None, input_inverse_transform=None): input_transform=None, input_inverse_transform=None):
if input_transform is not None: if input_transform is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment