diff --git a/plotting.py b/plotting.py
index bfb906c06850f0ddc462909dceb8963bac0b8212..f0dd64e82e9b7726e16772e6cee56b816add778a 100644
--- a/plotting.py
+++ b/plotting.py
@@ -32,8 +32,15 @@ def save_show(plt, fig, filename):
         return None
 
 
-def get_mean_event(x, y, class_label):
-    return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])]
+def get_mean_event(x, y, class_label, mask_value=None):
+    means = []
+    for var_index in range(x.shape[1]):
+        if mask_value is not None:
+            masked_values = np.where(x[:,var_index] == mask_value)[0]
+            x = x[masked_values]
+            y = y[masked_values]
+        means.append(np.mean(x[y==class_label][:,var_index]))
+    return means
 
 
 def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None):
diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 04016cbd396e9e22ca2f3d76ab9dc5f167911a24..460934daca75730d6f3d0df7b2d881af6c1595af 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -73,13 +73,21 @@ else:
 varx_label = args.varx
 vary_label = args.vary
 
-# percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
-# percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
-
 total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)]
 
-percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights)
-percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights)
+try:
+    mask_value = c.mask_value
+except NameError:
+    mask_value = None
+
+varx_test = c.x_test[:,varx_index]
+vary_test = c.x_test[:,vary_index]
+
+x_not_masked = np.where(varx_test != mask_value)[0]
+y_not_masked = np.where(vary_test != mask_value)[0]
+
+percentilesx = weighted_quantile(varx_test[x_not_masked], [0.1, 0.99], sample_weight=total_weights[x_not_masked])
+percentilesy = weighted_quantile(vary_test[y_not_masked], [0.1, 0.99], sample_weight=total_weights[y_not_masked])
 
 if args.xrange is not None:
     if len(args.xrange) < 3:
@@ -100,9 +108,11 @@ else:
 if args.mode.startswith("mean"):
 
     if args.mode == "mean_sig":
-        means = get_mean_event(c.x_test, c.y_test, 1)
+        means = get_mean_event(c.x_test, c.y_test, 1, mask_value=mask_value)
     elif args.mode == "mean_bkg":
-        means = get_mean_event(c.x_test, c.y_test, 0)
+        means = get_mean_event(c.x_test, c.y_test, 0, mask_value=mask_value)
+
+    print(means)
 
     if hasattr(c, "get_input_list"):
         input_transform = c.get_input_list