diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py index 108c35d53d6e18cf59956b6a11cefde384b51b00..58dbc7d92413f95ae25a0bb2d738aa5705580d97 100755 --- a/scripts/plot_NN_2D.py +++ b/scripts/plot_NN_2D.py @@ -11,7 +11,7 @@ import ROOT ROOT.gROOT.SetBatch() ROOT.PyConfig.IgnoreCommandLineOptions = True -from KerasROOTClassification import ClassificationProject +from KerasROOTClassification import ClassificationProject, load_from_dir from KerasROOTClassification.plotting import ( get_mean_event, plot_NN_vs_var_2D, @@ -54,7 +54,7 @@ if args.all_neurons and (not args.mode.startswith("mean")): if args.verbose: logging.getLogger().setLevel(logging.DEBUG) -c = ClassificationProject(args.project_dir) +c = load_from_dir(args.project_dir) plot_vs_activation = (args.vary == "activation") @@ -104,13 +104,22 @@ if args.mode.startswith("mean"): elif args.mode == "mean_bkg": means = get_mean_event(c.x_test, c.y_test, 0) + if hasattr(c, "get_input_list"): + input_transform = c.get_input_list + else: + input_transform = None + if not args.all_neurons: plot_NN_vs_var_2D( args.output_filename, means=means, varx_index=varx_index, vary_index=vary_index, - scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler), + scorefun=get_single_neuron_function( + c.model, layer, neuron, + scaler=c.scaler, + input_transform=input_transform + ), xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2], ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2], varx_label=varx_label, vary_label=vary_label, diff --git a/utils.py b/utils.py index 05fffea1a89c2042ee9a92f38cc6c545647232c9..316eba96295b66d5dd0e7ead7b51d4b4e09b92c0 100644 --- a/utils.py +++ b/utils.py @@ -13,7 +13,7 @@ from meme import cache logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -def get_single_neuron_function(model, layer, neuron, scaler=None): +def get_single_neuron_function(model, layer, neuron, scaler=None, input_transform=None): f = K.function([model.input]+[K.learning_phase()], [model.layers[layer].output[:,neuron]])