diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 65b7fafb4ecb5137ac83c29c1280472be7e9923c..112c1bf5dcf5fee8523060cc193df6af7bef79ce 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -126,11 +126,15 @@ if args.mode.startswith("mean"):
             logscale=args.log, only_pixels=(not args.contour)
         )
     else:
+        if hasattr(c, "get_input_list"):
+            transform_function = lambda inp : c.get_input_list(c.scaler.transform(inp))
+        else:
+            transform_function = c.scaler.transform(inp)
         plot_NN_vs_var_2D_all(
             args.output_filename,
             means=means,
             model=c.model,
-            transform_function=c.scaler.transform,
+            transform_function=transform_function,
             varx_index=varx_index,
             vary_index=vary_index,
             xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
diff --git a/utils.py b/utils.py
index 316eba96295b66d5dd0e7ead7b51d4b4e09b92c0..daac2da92452de90ef18ae1d291f6c12ca3f28ba 100644
--- a/utils.py
+++ b/utils.py
@@ -15,14 +15,21 @@ logger.addHandler(logging.NullHandler())
 
 def get_single_neuron_function(model, layer, neuron, scaler=None, input_transform=None):
 
-    f = K.function([model.input]+[K.learning_phase()], [model.layers[layer].output[:,neuron]])
+    inp = model.input
+    if not isinstance(inp, list):
+        inp = [inp]
+
+    f = K.function(inp+[K.learning_phase()], [model.layers[layer].output[:,neuron]])
 
     def eval_single_neuron(x):
+        x_eval = x
         if scaler is not None:
             x_eval = scaler.transform(x)
+        if input_transform is not None:
+            x_eval = input_transform(x_eval)
         else:
-            x_eval = x
-        return f([x_eval])[0]
+            x_eval = [x_eval]
+        return f(x_eval)[0]
 
     return eval_single_neuron