From 9d83f5ebdc108b6c7002ed829cbd9bf72c4ed787 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 21 Aug 2018 10:23:26 +0200
Subject: [PATCH] trying to consistently treat masking

---
 plotting.py           | 11 +++++++++--
 scripts/plot_NN_2D.py | 24 +++++++++++++++++-------
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/plotting.py b/plotting.py
index bfb906c..f0dd64e 100644
--- a/plotting.py
+++ b/plotting.py
@@ -32,8 +32,15 @@ def save_show(plt, fig, filename):
         return None
 
 
-def get_mean_event(x, y, class_label):
-    return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])]
+def get_mean_event(x, y, class_label, mask_value=None):
+    means = []
+    for var_index in range(x.shape[1]):
+        if mask_value is not None:
+            masked_values = np.where(x[:,var_index] == mask_value)[0]
+            x = x[masked_values]
+            y = y[masked_values]
+        means.append(np.mean(x[y==class_label][:,var_index]))
+    return means
 
 
 def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None):
diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 04016cb..460934d 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -73,13 +73,21 @@ else:
 varx_label = args.varx
 vary_label = args.vary
 
-# percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
-# percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
-
 total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)]
 
-percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights)
-percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights)
+try:
+    mask_value = c.mask_value
+except NameError:
+    mask_value = None
+
+varx_test = c.x_test[:,varx_index]
+vary_test = c.x_test[:,vary_index]
+
+x_not_masked = np.where(varx_test != mask_value)[0]
+y_not_masked = np.where(vary_test != mask_value)[0]
+
+percentilesx = weighted_quantile(varx_test[x_not_masked], [0.1, 0.99], sample_weight=total_weights[x_not_masked])
+percentilesy = weighted_quantile(vary_test[y_not_masked], [0.1, 0.99], sample_weight=total_weights[y_not_masked])
 
 if args.xrange is not None:
     if len(args.xrange) < 3:
@@ -100,9 +108,11 @@ else:
 if args.mode.startswith("mean"):
 
     if args.mode == "mean_sig":
-        means = get_mean_event(c.x_test, c.y_test, 1)
+        means = get_mean_event(c.x_test, c.y_test, 1, mask_value=mask_value)
     elif args.mode == "mean_bkg":
-        means = get_mean_event(c.x_test, c.y_test, 0)
+        means = get_mean_event(c.x_test, c.y_test, 0, mask_value=mask_value)
+
+    print(means)
 
     if hasattr(c, "get_input_list"):
         input_transform = c.get_input_list
-- 
GitLab