diff --git a/plotting.py b/plotting.py
index 4d071fd805e8e047d8444a705c0ebcb3c268d7ba..ecad66023599bc44be1917d0476c44dabe958378 100644
--- a/plotting.py
+++ b/plotting.py
@@ -6,8 +6,11 @@ import math
 import matplotlib.pyplot as plt
 import matplotlib.colors
 from matplotlib.ticker import LogFormatter
+from mpl_toolkits.axes_grid1 import ImageGrid
 import numpy as np
 
+from .keras-visualize-activations.read_activations import get_activations
+
 import meme
 
 """
@@ -97,6 +100,66 @@ def plot_NN_vs_var_2D(plotname, means,
     fig.savefig(plotname)
 
 
+def plot_NN_vs_var_2D_all(plotname, model, means,
+                          var1_index, var1_range,
+                          var2_index, var2_range,
+                          transform_function=None,
+                          var1_label=None,
+                          var2_label=None):
+
+    "Similar to plot_NN_vs_var_2D, but creates a grid of plots for all neurons."
+
+    # var1 = "lep1Phi"
+    # var2 = "met_Phi"
+    # # var1_vals = np.arange(-3.15,3.15,0.1)
+    # # var2_vals = np.arange(-3.15,3.15,0.1)
+    var1_vals = np.arange(*var1_range)
+    var2_vals = np.arange(*var2_range)
+
+    # create the events for which we want to fetch the activations
+    events = np.tile(means, len(var1_vals)*len(var2_vals)).reshape(len(var2_vals), len(var1_vals), -1)
+    for i, y in enumerate(var2_vals):
+        for j, x in enumerate(var1_vals):
+            events[i][j][index1] = x
+            events[i][j][index2] = y
+
+    # convert back into 1d array
+    events = events.reshape(-1, len(means))
+
+    # transform
+    if transform_function is not None:
+        #events = c.scaler.transform(events)
+        events = transform_function(events)
+
+    acts = get_activations(model, events, print_shape_only=True)
+
+    aspect = (var1_vals[-1]-var1_vals[0])/(var2_vals[-1]-var2_vals[0])
+
+    n_neurons = [len(i[0]) for i in acts]
+    layers = len(n_neurons)
+
+    nrows_ncols = (layers, max(n_neurons))
+    fig = plt.figure(1, figsize=nrows_ncols)
+    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0, label_mode="1")
+    grid = np.array(grid)
+    grid = grid.reshape(*nrows_ncols[::-1])
+
+    for layer in range(layers):
+        for neuron in range(len(acts[layer][0])):
+            acts_neuron = acts[layer][:,neuron]
+            acts_neuron = acts_neuron.reshape(len(var2_vals), len(var1_vals))
+            ax = grid[neuron][layer]
+            ax.imshow(acts_neuron, origin="lower", extent=[var1_vals[0], var1_vals[-1], var2_vals[0], var2_vals[-1]], aspect=aspect, cmap="jet")
+            if var1_label is not None:
+                ax.set_xlabel(var1_label)
+            if var2_label is not None:
+                ax.set_ylabel(var2_label)
+            ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes)
+
+    fig.savefig(plotname, bbox_inches='tight')
+
+
+
 
 if __name__ == "__main__":
 
@@ -128,3 +191,10 @@ if __name__ == "__main__":
                       var2_index=c.branches.index("mt"), var2_range=(0, 500, 10),
                       var1_label="met [GeV]", var2_label="mt [GeV]")
 
+
+    plot_NN_vs_var_2D_all("mt_vs_met_all.pdf", means=mean_signal,
+                          model=c.model, transform_function=c.scaler.transform,
+                          var1_index=c.branches.index("met"), var1_range=(0, 1000, 10),
+                          var2_index=c.branches.index("mt"), var2_range=(0, 500, 10),
+                          var1_label="met [GeV]", var2_label="mt [GeV]")
+