From 96759ea306a87d31483f40c7576690517057586a Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 21 Aug 2018 18:01:01 +0200
Subject: [PATCH] fixing range in plot_NN_2D

---
 scripts/plot_NN_2D.py | 15 ++++++---------
 utils.py              | 12 ++++++++++++
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 0a0be89..1555095 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -20,7 +20,12 @@ from KerasROOTClassification.plotting import (
     plot_cond_avg_actmax_2D,
     plot_NN_vs_var_2D_all,
 )
-from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events, weighted_quantile
+from KerasROOTClassification.utils import (
+    get_single_neuron_function,
+    get_max_activation_events,
+    weighted_quantile,
+    get_ranges
+)
 
 parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron')
 parser.add_argument("project_dir")
@@ -89,14 +94,6 @@ except NameError:
 # percentilesx = weighted_quantile(varx_test[x_not_masked], [0.01, 0.99], sample_weight=total_weights[x_not_masked])
 # percentilesy = weighted_quantile(vary_test[y_not_masked], [0.01, 0.99], sample_weight=total_weights[y_not_masked])
 
-def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None):
-    ranges = []
-    for var_index in range(x.shape[1]):
-        x_var = x[:,var_index]
-        not_masked = np.where(x_var != mask_value)[0]
-        ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked]))
-    return ranges
-
 percentilesx = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=varx_index)[0]
 percentilesy = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=vary_index)[0]
 
diff --git a/utils.py b/utils.py
index a71f251..f3323c4 100644
--- a/utils.py
+++ b/utils.py
@@ -39,6 +39,18 @@ def create_random_event(ranges):
     return random_event
 
 
+def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None):
+    "Get ranges for plotting or random event generation based on quantiles"
+    ranges = []
+    for var_index in range(x.shape[1]):
+        if (filter_index is not None) and (var_index != filter_index):
+            continue
+        x_var = x[:,var_index]
+        not_masked = np.where(x_var != mask_value)[0]
+        ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked]))
+    return ranges
+
+
 def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[],
                              input_transform=None, input_inverse_transform=None):
     if input_transform is not None:
-- 
GitLab