diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 680d657769e9e030e4aa184bf6ee9cb33d895345..0a0be891f18635147eb7c7d563e819ed35c29bfe 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -220,7 +220,18 @@ elif args.mode.startswith("hist"):
         # ranges in which to sample the random events
         x_test_scaled = c.transform(c.x_test)
         ranges = get_ranges(x_test_scaled, [0.01, 0.99], total_weights, mask_value=mask_value)
-        losses, events = get_max_activation_events(c.model, ranges, ntries=args.ntries_actmax, step=args.step_size, layer=layer, neuron=neuron, threshold=args.threshold)
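+        # multi-input models: pass the transforms that split and re-flatten
+        # the event arrays, if the classifier provides them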
+        kwargs = {}
+        if hasattr(c, "get_input_list"):
+            kwargs["input_transform"] = c.get_input_list
+            kwargs["input_inverse_transform"] = c.get_input_flat
+        losses, events = get_max_activation_events(c.model, ranges,
+                                                   ntries=args.ntries_actmax,
+                                                   step=args.step_size,
+                                                   layer=layer,
+                                                   neuron=neuron,
+                                                   threshold=args.threshold,
+                                                   **kwargs)
         events = c.inverse_transform(events)
         valsx = events[:,varx_index]
         if not plot_vs_activation:
diff --git a/utils.py b/utils.py
index 39ccc4865990a98c6afcf75fd086915849049f13..a71f251815679b7e4ebf2e76de4acc3acc8ef449 100644
--- a/utils.py
+++ b/utils.py
@@ -39,26 +39,49 @@ def create_random_event(ranges):
     return random_event
 
 
-def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[]):
-    for i in range(maxit):
-        loss_value, grads_value = gradient_function([random_event])
-        for const_index in const_indices:
-            grads_value[0][const_index] = 0
-        if threshold is not None:
-            if loss_value > threshold and (maxthreshold is None or loss_value < maxthreshold):
-                # found an event within the thresholds
-                return loss_value, random_event
-            elif (maxthreshold is not None and loss_value > maxthreshold):
-                random_event -= grads_value*step
-            else:
-                random_event += grads_value*step
+def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[],
+                             input_transform=None, input_inverse_transform=None):
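+    # optionally split the flat event into the list of arrays a multi-input model expects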
+    if input_transform is not None:
+        random_event = input_transform(random_event)
+    if not isinstance(random_event, list):
+        random_event = [random_event]
+
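+    # run gradient ascent on the activation, updating each input array in place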
+    def iterate(random_event):
+        for _ in range(maxit):
+            grads_out = gradient_function(random_event)
+            loss_value = grads_out[0][0]
+            grads_values = grads_out[1:]
+            if threshold is not None and loss_value > threshold and (maxthreshold is None or loss_value < maxthreshold):
+                # found an event within the thresholds
+                return loss_value, random_event
+            # follow the gradient for all inputs
+            for i, grads_value in enumerate(grads_values):
+                for const_index in const_indices:
+                    grads_value[0][const_index] = 0
+                if threshold is not None and maxthreshold is not None and loss_value > maxthreshold:
+                    # activation too high - walk back down the gradient
+                    random_event[i] -= grads_value*step
+                else:
+                    random_event[i] += grads_value*step
         else:
-            random_event += grads_value*step
-    else:
-        if threshold is not None:
-            # no event found
-            return None
-    # if no threshold requested, always return last status
+            if threshold is not None:
+                # no event found for the given threshold
+                return None, None
+        # otherwise return last status
+        return loss_value, random_event
+
+    loss_value, random_event = iterate(random_event)
+    if random_event is None:
+        # no event found within the thresholds
+        return None
+    if input_inverse_transform is not None:
+        random_event = input_inverse_transform(random_event)
+    elif len(random_event) == 1:
+        # single input was wrapped into a list above - unwrap to keep the old return shape
+        random_event = random_event[0]
     return loss_value, random_event
 
 
@@ -66,12 +89,18 @@ def get_grad_function(model, layer, neuron):
 
     loss = model.layers[layer].output[:,neuron]
 
-    grads = K.gradients(loss, model.input)[0]
+    grads = K.gradients(loss, model.input)
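+    # K.gradients returns one gradient tensor per model input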
 
     # trick from https://blog.keras.io/how-convolutional-neural-networks-see-the-world.html
-    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)
+    norm_grads = [grad/(K.sqrt(K.mean(K.square(grad))) + 1e-5) for grad in grads]
+
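+    # K.function expects the inputs as a list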
+    inp = model.input
+    if not isinstance(inp, list):
+        inp = [inp]
 
-    return K.function([model.input], [loss, grads])
+    return K.function(inp, [loss]+norm_grads)
 
 
 @cache(useJSON=True,
@@ -79,6 +108,8 @@ def get_grad_function(model, layer, neuron):
            lambda model: [hash(i.tostring()) for i in model.get_weights()],
            lambda ranges: [hash(i.tostring()) for i in ranges],
        ],
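+       # the transform functions can't be JSON-serialised into the cache key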
+       ignoreKwargs=["input_transform", "input_inverse_transform"],
 )
 def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **kwargs):
 
@@ -90,9 +121,13 @@ def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **k
     for i in range(ntries):
         if not (i%100):
             logger.info(i)
-        res = max_activation_wrt_input(gradient_function, create_random_event(ranges), **kwargs)
+        res = max_activation_wrt_input(gradient_function,
+                                       create_random_event(ranges),
+                                       **kwargs)
         if res is not None:
             loss, event = res
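+            # the gradient function now yields a python scalar - wrap it back into an array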
+            loss = np.array([loss])
         else:
             continue
         if events is None: