From 5a494d99f0229b1f46db969c9d7d368ceabe88e0 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 20 Aug 2018 10:21:10 +0200
Subject: [PATCH] not working yet

---
 scripts/plot_NN_2D.py | 15 ++++++++++++---
 utils.py              |  2 +-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 108c35d..58dbc7d 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -11,7 +11,7 @@ import ROOT
 ROOT.gROOT.SetBatch()
 ROOT.PyConfig.IgnoreCommandLineOptions = True
 
-from KerasROOTClassification import ClassificationProject
+from KerasROOTClassification import ClassificationProject, load_from_dir
 from KerasROOTClassification.plotting import (
     get_mean_event,
     plot_NN_vs_var_2D,
@@ -54,7 +54,7 @@ if args.all_neurons and (not args.mode.startswith("mean")):
 if args.verbose:
     logging.getLogger().setLevel(logging.DEBUG)
 
-c = ClassificationProject(args.project_dir)
+c = load_from_dir(args.project_dir)
 
 plot_vs_activation = (args.vary == "activation")
 
@@ -104,13 +104,22 @@ if args.mode.startswith("mean"):
     elif args.mode == "mean_bkg":
         means = get_mean_event(c.x_test, c.y_test, 0)
 
+    if hasattr(c, "get_input_list"):
+        input_transform = c.get_input_list
+    else:
+        input_transform = None
+
     if not args.all_neurons:
         plot_NN_vs_var_2D(
             args.output_filename,
             means=means,
             varx_index=varx_index,
             vary_index=vary_index,
-            scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
+            scorefun=get_single_neuron_function(
+                c.model, layer, neuron,
+                scaler=c.scaler,
+                input_transform=input_transform
+            ),
             xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
             ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
             varx_label=varx_label, vary_label=vary_label,
diff --git a/utils.py b/utils.py
index 05fffea..316eba9 100644
--- a/utils.py
+++ b/utils.py
@@ -13,7 +13,7 @@ from meme import cache
 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
 
-def get_single_neuron_function(model, layer, neuron, scaler=None):
+def get_single_neuron_function(model, layer, neuron, scaler=None, input_transform=None):
 
     f = K.function([model.input]+[K.learning_phase()], [model.layers[layer].output[:,neuron]])
 
-- 
GitLab