diff --git a/plotting.py b/plotting.py
index db24def073d5f9aaa7a1eb55f86e3e508ec3bb11..ae25dfc096feffa92f3470e134155b7439638f6c 100644
--- a/plotting.py
+++ b/plotting.py
@@ -120,11 +120,13 @@ def plot_NN_vs_var_2D(plotname, means,
 
 
 def plot_NN_vs_var_2D_all(plotname, model, means,
-                          var1_index, var1_range,
-                          var2_index, var2_range,
+                          varx_index,
+                          vary_index,
+                          nbinsx, xmin, xmax,
+                          nbinsy, ymin, ymax,
                           transform_function=None,
-                          var1_label=None,
-                          var2_label=None,
+                          varx_label=None,
+                          vary_label=None,
                           zrange=None, logz=False,
                           plot_last_layer=False,
                           log_default_ymin=1e-5,
@@ -132,15 +134,15 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
 
     "Similar to plot_NN_vs_var_2D, but creates a grid of plots for all neurons."
 
-    var1_vals = np.arange(*var1_range)
-    var2_vals = np.arange(*var2_range)
+    varx_vals = np.linspace(xmin, xmax, nbinsx)
+    vary_vals = np.linspace(ymin, ymax, nbinsy)
 
     # create the events for which we want to fetch the activations
-    events = np.tile(means, len(var1_vals)*len(var2_vals)).reshape(len(var2_vals), len(var1_vals), -1)
-    for i, y in enumerate(var2_vals):
-        for j, x in enumerate(var1_vals):
-            events[i][j][var1_index] = x
-            events[i][j][var2_index] = y
+    events = np.tile(means, len(varx_vals)*len(vary_vals)).reshape(len(vary_vals), len(varx_vals), -1)
+    for i, y in enumerate(vary_vals):
+        for j, x in enumerate(varx_vals):
+            events[i][j][varx_index] = x
+            events[i][j][vary_index] = y
 
     # convert back into 1d array
     events = events.reshape(-1, len(means))
@@ -187,7 +189,7 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
     for layer in range(layers):
         for neuron in range(len(acts[layer][0])):
             acts_neuron = acts[layer][:,neuron]
-            acts_neuron = acts_neuron.reshape(len(var2_vals), len(var1_vals))
+            acts_neuron = acts_neuron.reshape(len(vary_vals), len(varx_vals))
             ax = grid_array[neuron][layer]
             extra_opts = {}
             if not (plot_last_layer and layer == layers-1):
@@ -200,12 +202,12 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
                     extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
                 else:
                     extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
-            im = ax.pcolormesh(var1_vals, var2_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
+            im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
             ax.set_facecolor("black")
-            if var1_label is not None:
-                ax.set_xlabel(var1_label)
-            if var2_label is not None:
-                ax.set_ylabel(var2_label)
+            if varx_label is not None:
+                ax.set_xlabel(varx_label)
+            if vary_label is not None:
+                ax.set_ylabel(vary_label)
             ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
 
     cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
@@ -342,6 +344,8 @@ if __name__ == "__main__":
 
     def test_mean_signal():
 
+        c._load_data() # untransformed
+
         mean_signal = get_mean_event(c.x_test, c.y_test, 1)
 
         print("Mean signal: ")
@@ -371,9 +375,11 @@ if __name__ == "__main__":
 
         plot_NN_vs_var_2D_all("mt_vs_met_all.pdf", means=mean_signal,
                               model=c.model, transform_function=c.scaler.transform,
-                              var1_index=c.fields.index("met"), var1_range=(0, 1000, 10),
-                              var2_index=c.fields.index("mt"), var2_range=(0, 500, 10),
-                              var1_label="met [GeV]", var2_label="mt [GeV]")
+                              varx_index=c.fields.index("met"),
+                              vary_index=c.fields.index("mt"),
+                              nbinsx=100, xmin=0, xmax=1000,
+                              nbinsy=100, ymin=0, ymax=500,
+                              varx_label="met [GeV]", vary_label="mt [GeV]")
 
         plot_NN_vs_var_2D("mt_vs_met_crosscheck.pdf", means=mean_signal,
                           scorefun=get_single_neuron_function(c.model, layer=3, neuron=0, scaler=c.scaler),
diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index ba54b1eb6551d2bb2baefb985d9cbfc7b43557b3..460148f0cb802b1417382ef5edf731401a07494c 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -7,13 +7,18 @@ logging.basicConfig()
 
 import numpy as np
 
+import ROOT
+ROOT.gROOT.SetBatch()
+ROOT.PyConfig.IgnoreCommandLineOptions = True
+
 from KerasROOTClassification import ClassificationProject
 from KerasROOTClassification.plotting import (
     get_mean_event,
     plot_NN_vs_var_2D,
     plot_profile_2D,
     plot_hist_2D_events,
-    plot_cond_avg_actmax_2D
+    plot_cond_avg_actmax_2D,
+    plot_NN_vs_var_2D_all,
 )
 from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events
 
@@ -27,6 +32,7 @@ parser.add_argument("-m", "--mode",
                     default="mean_sig")
 parser.add_argument("-l", "--layer", type=int, help="Layer index (takes last layer by default)")
 parser.add_argument("-n", "--neuron", type=int, default=0, help="Neuron index (takes first neuron by default)")
+parser.add_argument("-a", "--all-neurons", action="store_true", help="Create a summary plot for all neurons in all hidden layers")
 parser.add_argument("--log", action="store_true", help="Plot in color in log scale")
 parser.add_argument("--contour", action="store_true", help="Interpolate with contours")
 parser.add_argument("-b", "--nbins", default=20, type=int, help="Number of bins in x and y direction")
@@ -42,6 +48,9 @@ parser.add_argument("-s", "--step-size", help="step size for activation maximisa
 
 args = parser.parse_args()
 
+if args.all_neurons and (not args.mode.startswith("mean")):
+    parser.error("--all-neurons currently only supported for mean_sig and mean_bkg")
+
 if args.verbose:
     logging.getLogger().setLevel(logging.DEBUG)
 
@@ -90,17 +99,31 @@ if args.mode.startswith("mean"):
     elif args.mode == "mean_bkg":
         means = get_mean_event(c.x_test, c.y_test, 0)
 
-    plot_NN_vs_var_2D(
-        args.output_filename,
-        means=means,
-        varx_index=varx_index,
-        vary_index=vary_index,
-        scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
-        xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
-        ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
-        varx_label=varx_label, vary_label=vary_label,
-        logscale=args.log, only_pixels=(not args.contour)
-    )
+    if not args.all_neurons:
+        plot_NN_vs_var_2D(
+            args.output_filename,
+            means=means,
+            varx_index=varx_index,
+            vary_index=vary_index,
+            scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
+            xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
+            ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
+            varx_label=varx_label, vary_label=vary_label,
+            logscale=args.log, only_pixels=(not args.contour)
+        )
+    else:
+        plot_NN_vs_var_2D_all(
+            args.output_filename,
+            means=means,
+            model=c.model,
+            transform_function=c.scaler.transform,
+            varx_index=varx_index,
+            vary_index=vary_index,
+            xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
+            ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
+            logz=args.log,
+            plot_last_layer=False,
+        )
 
 elif args.mode.startswith("profile"):