diff --git a/plotting.py b/plotting.py
index 210cab1db657852fc0634f1443239baee4097111..be7bce236b3d6b9efa010b83362a4dd0dbcadbd5 100644
--- a/plotting.py
+++ b/plotting.py
@@ -11,8 +11,6 @@ import numpy as np
 
 from .keras_visualize_activations.read_activations import get_activations
 
-import meme
-
 """
 Some further plotting functions
 """
@@ -107,14 +105,12 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
                           var2_index, var2_range,
                           transform_function=None,
                           var1_label=None,
-                          var2_label=None):
+                          var2_label=None,
+                          zrange=None, logz=False,
+                          plot_last_layer=False):
 
     "Similar to plot_NN_vs_var_2D, but creates a grid of plots for all neurons."
 
-    # var1 = "lep1Phi"
-    # var2 = "met_Phi"
-    # # var1_vals = np.arange(-3.15,3.15,0.1)
-    # # var2_vals = np.arange(-3.15,3.15,0.1)
     var1_vals = np.arange(*var1_range)
     var2_vals = np.arange(*var2_range)
 
@@ -130,34 +126,70 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
 
     # transform
     if transform_function is not None:
-        #events = c.scaler.transform(events)
         events = transform_function(events)
 
     acts = get_activations(model, events, print_shape_only=True)
 
     aspect = (var1_vals[-1]-var1_vals[0])/(var2_vals[-1]-var2_vals[0])
 
-    n_neurons = [len(i[0]) for i in acts]
+    if plot_last_layer:
+        n_neurons = [len(i[0]) for i in acts]
+    else:
+        n_neurons = [len(i[0]) for i in acts[:-1]]
     layers = len(n_neurons)
 
     nrows_ncols = (layers, max(n_neurons))
     fig = plt.figure(1, figsize=nrows_ncols)
-    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0, label_mode="1")
-    grid = np.array(grid)
-    grid = grid.reshape(*nrows_ncols[::-1])
+    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0,
+                     label_mode="1",
+                     cbar_location="top",
+                     cbar_mode="single",)
+    grid_array = np.array(grid)
+    grid_array = grid_array.reshape(*nrows_ncols[::-1])
+
+    # leave out the last layer
+    global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]])
+    global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]])
+
+    print("global_min: {}".format(global_min))
+    print("global_max: {}".format(global_max))
+
+    output_min_default = 0
+    output_max_default = 1
+
+    if global_min <= 0 and logz:
+        min_exponent = -5
+        global_min = 10**min_exponent
+        output_min_default = global_min
+        print("Changing global_min to {}".format(global_min))
 
     for layer in range(layers):
         for neuron in range(len(acts[layer][0])):
             acts_neuron = acts[layer][:,neuron]
             acts_neuron = acts_neuron.reshape(len(var2_vals), len(var1_vals))
-            ax = grid[neuron][layer]
-            ax.imshow(acts_neuron, origin="lower", extent=[var1_vals[0], var1_vals[-1], var2_vals[0], var2_vals[-1]], aspect=aspect, cmap="jet")
+            ax = grid_array[neuron][layer]
+            extra_opts = {}
+            if not (plot_last_layer and layer == layers-1):
+                # for hidden layers, plot the same z-scale
+                if logz:
+                    norm = matplotlib.colors.LogNorm
+                else:
+                    norm = matplotlib.colors.Normalize
+                if zrange is not None:
+                    extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
+                else:
+                    extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
+            im = ax.imshow(acts_neuron, origin="lower", extent=[var1_vals[0], var1_vals[-1], var2_vals[0], var2_vals[-1]], aspect=aspect, cmap="jet", **extra_opts)
             if var1_label is not None:
                 ax.set_xlabel(var1_label)
             if var2_label is not None:
                 ax.set_ylabel(var2_label)
             ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes)
 
+    cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
+    cb.ax.xaxis.set_ticks_position('top')
+    cb.ax.xaxis.set_label_position('top')
+
     fig.savefig(plotname, bbox_inches='tight')
     plt.close(fig)