diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py index 65b7fafb4ecb5137ac83c29c1280472be7e9923c..112c1bf5dcf5fee8523060cc193df6af7bef79ce 100755 --- a/scripts/plot_NN_2D.py +++ b/scripts/plot_NN_2D.py @@ -126,11 +126,15 @@ if args.mode.startswith("mean"): logscale=args.log, only_pixels=(not args.contour) ) else: + if hasattr(c, "get_input_list"): + transform_function = lambda inp : c.get_input_list(c.scaler.transform(inp)) + else: + transform_function = c.scaler.transform(inp) plot_NN_vs_var_2D_all( args.output_filename, means=means, model=c.model, - transform_function=c.scaler.transform, + transform_function=transform_function, varx_index=varx_index, vary_index=vary_index, xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2], diff --git a/utils.py b/utils.py index 316eba96295b66d5dd0e7ead7b51d4b4e09b92c0..daac2da92452de90ef18ae1d291f6c12ca3f28ba 100644 --- a/utils.py +++ b/utils.py @@ -15,14 +15,21 @@ logger.addHandler(logging.NullHandler()) def get_single_neuron_function(model, layer, neuron, scaler=None, input_transform=None): - f = K.function([model.input]+[K.learning_phase()], [model.layers[layer].output[:,neuron]]) + inp = model.input + if not isinstance(inp, list): + inp = [inp] + + f = K.function(inp+[K.learning_phase()], [model.layers[layer].output[:,neuron]]) def eval_single_neuron(x): + x_eval = x if scaler is not None: x_eval = scaler.transform(x) + if input_transform is not None: + x_eval = input_transform(x_eval) else: - x_eval = x - return f([x_eval])[0] + x_eval = [x_eval] + return f(x_eval)[0] return eval_single_neuron