From 9ee7e3277b913f2a7b14a54424ac1cc0646cd1a4 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 21 Aug 2018 15:20:41 +0200
Subject: [PATCH] weighted quantiles for all functions in plot_NN_2D

---
 scripts/plot_NN_2D.py | 34 +++++++++++++++++++++++-----------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index a5b65b0..680d657 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -80,14 +80,26 @@ try:
 except NameError:
     mask_value = None
 
-varx_test = c.x_test[:,varx_index]
-vary_test = c.x_test[:,vary_index]
+# varx_test = c.x_test[:,varx_index]
+# vary_test = c.x_test[:,vary_index]
 
-x_not_masked = np.where(varx_test != mask_value)[0]
-y_not_masked = np.where(vary_test != mask_value)[0]
+# x_not_masked = np.where(varx_test != mask_value)[0]
+# y_not_masked = np.where(vary_test != mask_value)[0]
+
+# percentilesx = weighted_quantile(varx_test[x_not_masked], [0.01, 0.99], sample_weight=total_weights[x_not_masked])
+# percentilesy = weighted_quantile(vary_test[y_not_masked], [0.01, 0.99], sample_weight=total_weights[y_not_masked])
+
+def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None):
+    "Return a list of weighted_quantile results, one per column of x (only column filter_index if given); entries equal to mask_value are ignored"
+    ranges = []
+    for var_index in (range(x.shape[1]) if filter_index is None else [filter_index]):  # honour filter_index so callers taking [0] get the requested column
+        not_masked = np.where(x[:,var_index] != mask_value)[0]  # row indices whose value in this column is not the mask value
+        ranges.append(weighted_quantile(x[not_masked,var_index], quantiles, sample_weight=weights[not_masked]))
+    return ranges
+
+percentilesx = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=varx_index)[0]
+percentilesy = get_ranges(c.x_test, [0.01, 0.99], total_weights, mask_value=mask_value, filter_index=vary_index)[0]
 
-percentilesx = weighted_quantile(varx_test[x_not_masked], [0.01, 0.99], sample_weight=total_weights[x_not_masked])
-percentilesy = weighted_quantile(vary_test[y_not_masked], [0.01, 0.99], sample_weight=total_weights[y_not_masked])
 
 if args.xrange is not None:
     if len(args.xrange) < 3:
@@ -206,10 +218,10 @@ elif args.mode.startswith("hist"):
         weights = c.w_test[c.y_test==class_index]
     else:
         # ranges in which to sample the random events
-        x_test_scaled = c.scaler.transform(c.x_test)
-        ranges = [np.percentile(x_test_scaled[:,var_index], [1,99]) for var_index in range(len(c.fields))]
+        x_test_scaled = c.transform(c.x_test)
+        ranges = get_ranges(x_test_scaled, [0.01, 0.99], total_weights, mask_value=mask_value)
         losses, events = get_max_activation_events(c.model, ranges, ntries=args.ntries_actmax, step=args.step_size, layer=layer, neuron=neuron, threshold=args.threshold)
-        events = c.scaler.inverse_transform(events)
+        events = c.inverse_transform(events)
         valsx = events[:,varx_index]
         if not plot_vs_activation:
             valsy = events[:,vary_index]
@@ -229,10 +241,10 @@ elif args.mode.startswith("hist"):
 
 elif args.mode.startswith("cond_actmax"):
 
-    x_test_scaled = c.scaler.transform(c.x_test)
+    x_test_scaled = c.transform(c.x_test)
 
     # ranges in which to sample the random events
-    ranges = [np.percentile(x_test_scaled[:,var_index], [1,99]) for var_index in range(len(c.fields))]
+    ranges = get_ranges(x_test_scaled, [0.01, 0.99], total_weights, mask_value=mask_value)
 
     plot_cond_avg_actmax_2D(
         args.output_filename,
-- 
GitLab