diff --git a/plotting.py b/plotting.py
index 6d5f3c4e9ac2e06dde334b70c78ed8c2b4c499de..d0d73155927dbcbaf4668d51b58542b3ded0b5fb 100644
--- a/plotting.py
+++ b/plotting.py
@@ -258,9 +258,9 @@ def plot_profile_2D_all(plotname, model, events,
     logger.info("Done")
 
     if plot_last_layer:
-        n_neurons = [len(i[0]) for i in acts]
+        n_neurons = [len(i.reshape(i.shape[0], -1)[0]) for i in acts]
     else:
-        n_neurons = [len(i[0]) for i in acts[:-1]]
+        n_neurons = [len(i.reshape(i.shape[0], -1)[0]) for i in acts[:-1]]
     layers = len(n_neurons)
 
     nrows_ncols = (layers, max(n_neurons))
@@ -282,8 +282,10 @@ def plot_profile_2D_all(plotname, model, events,
     ims = []
     reg_plots = []
     for layer in range(layers):
-        for neuron in range(len(acts[layer][0])):
-            acts_neuron = acts[layer][:,neuron]
+        neurons_acts = acts[layer]
+        neurons_acts = neurons_acts.reshape(neurons_acts.shape[0], -1)
+        for neuron in range(len(neurons_acts[0])):
+            acts_neuron = neurons_acts[:,neuron]
             ax = grid_array[neuron][layer]
             extra_opts = {}
             if not (plot_last_layer and layer == layers-1):