diff --git a/plotting.py b/plotting.py
index d1d1d286ec54294c368033ed658da42ea3ccb883..101fad04491a37b376b67b6876234ad00f64632f 100644
--- a/plotting.py
+++ b/plotting.py
@@ -235,7 +235,7 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
 
 
 def plot_profile_2D_all(plotname, model, events,
-                        varx_index, vary_index,
+                        valsx, valsy,
                         nbinsx, xmin, xmax,
                         nbinsy, ymin, ymax,
                         transform_function=None,
@@ -246,16 +246,15 @@ def plot_profile_2D_all(plotname, model, events,
                         log_default_ymin=1e-5,
                         cmap="inferno", **kwargs):
 
-    "Similar to plot_NN_2D, but creates a grid of plots for all neurons."
-
-    valsx = np.array(events[:,varx_index])
-    valsy = np.array(events[:,vary_index])
+    "Similar to plot_profile_2D, but creates a grid of plots for all neurons."
 
     # transform
     if transform_function is not None:
         events = transform_function(events)
 
+    logger.info("Reading activations for all neurons")
     acts = get_activations(model, events, print_shape_only=True)
+    logger.info("Done")
 
     if plot_last_layer:
         n_neurons = [len(i[0]) for i in acts]
@@ -275,21 +274,12 @@ def plot_profile_2D_all(plotname, model, events,
     grid_array = np.array(grid)
     grid_array = grid_array.reshape(*nrows_ncols[::-1])
 
-    # leave out the last layer
-    global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]])
-    global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]])
-
-    logger.info("global_min: {}".format(global_min))
-    logger.info("global_max: {}".format(global_max))
-
-    output_min_default = 0
-    output_max_default = 1
-
-    if global_min <= 0 and logz:
-        global_min = log_default_ymin
-        logger.info("Changing global_min to {}".format(log_default_ymin))
+    global_min = None
+    global_max = None
 
+    logger.info("Creating profile histograms")
     ims = []
+    reg_plots = []
     for layer in range(layers):
         for neuron in range(len(acts[layer][0])):
             acts_neuron = acts[layer][:,neuron]
@@ -305,27 +295,46 @@ def plot_profile_2D_all(plotname, model, events,
                     extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
                 else:
                     extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
-            #im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
             hist, xedges, yedges = get_profile_2D(
                 valsx, valsy, acts_neuron,
                 nbinsx, xmin, xmax,
                 nbinsy, ymin, ymax,
                 **kwargs
             )
+            if global_min is None or hist.min() < global_min:
+                global_min = hist.min()
+            if global_max is None or hist.max() > global_max:
+                global_max = hist.max()
             X, Y = np.meshgrid(xedges, yedges)
-            im = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extra_opts)
-            ax.set_facecolor("black")
-            if varx_label is not None:
-                ax.set_xlabel(varx_label)
-            if vary_label is not None:
-                ax.set_ylabel(vary_label)
-            ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
+            reg_plots.append((layer, neuron, ax, (X, Y, hist), dict(cmap="inferno", linewidth=0, rasterized=True, **extra_opts)))
+    logger.info("Done")
+
+    logger.info("global_min: {}".format(global_min))
+    logger.info("global_max: {}".format(global_max))
+
+    if global_min <= 0 and logz:
+        global_min = log_default_ymin
+        logger.info("Changing global_min to {}".format(log_default_ymin))
+
+    for layer, neuron, ax, args, kwargs in reg_plots:
+        if zrange is None:
+            kwargs["norm"].vmin = global_min
+            kwargs["norm"].vmax = global_max
+        im = ax.pcolormesh(*args, **kwargs)
+        ax.set_facecolor("black")
+        if varx_label is not None:
+            ax.set_xlabel(varx_label)
+        if vary_label is not None:
+            ax.set_ylabel(vary_label)
+        ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
 
     cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
     cb.ax.xaxis.set_ticks_position('top')
     cb.ax.xaxis.set_label_position('top')
 
+    logger.info("Rendering")
     save_show(plt, fig, plotname, bbox_inches='tight')
+    logger.info("Done")
 
 
 def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"):