From 4405753d23aa24e40217d683f4ae5ac048396f0e Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 21 Aug 2018 11:27:50 +0200
Subject: [PATCH] fixing input_transform

Apply the scaler transform before get_input_list when building
input_transform in plot_NN_2D.py (and drop the now-redundant scaler
argument), select the unmasked entries in get_mean_event, and allocate
the _batch_transform output with the input's shape and dtype.

---
 plotting.py           | 2 +-
 scripts/plot_NN_2D.py | 5 ++---
 toolkit.py            | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/plotting.py b/plotting.py
index 749d2a9..db4df95 100644
--- a/plotting.py
+++ b/plotting.py
@@ -36,7 +36,7 @@ def get_mean_event(x, y, class_label, mask_value=None):
     means = []
     for var_index in range(x.shape[1]):
         if mask_value is not None:
-            masked_values = np.where(x[:,var_index] == mask_value)[0]
+            masked_values = np.where(x[:,var_index] != mask_value)[0]
             x = x[masked_values]
             y = y[masked_values]
         means.append(np.mean(x[y==class_label][:,var_index]))
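
The ==/!= flip above matters because np.where(condition)[0] returns the
indices where the condition holds: with == the function kept exactly the
masked (padding) rows and averaged those, while != keeps the real entries.
A minimal sketch of the intended selection, with made-up data and a
made-up mask value of -999:

    import numpy as np

    x = np.array([[1.0, 2.0],
                  [-999.0, 4.0],   # -999 stands in for the project's mask value
                  [3.0, 6.0]])
    y = np.array([0, 0, 0])
    mask_value = -999.0
    var_index = 0

    # keep the rows whose value in this column is NOT the mask value
    kept = np.where(x[:, var_index] != mask_value)[0]
    x, y = x[kept], y[kept]
    print(np.mean(x[y == 0][:, var_index]))   # 2.0, mean of the unmasked values
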
diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 460934d..8a75af3 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -115,9 +115,9 @@ if args.mode.startswith("mean"):
     print(means)
 
     if hasattr(c, "get_input_list"):
-        input_transform = c.get_input_list
+        input_transform = lambda x : c.get_input_list(c.transform(x))
     else:
-        input_transform = None
+        input_transform = c.transform
 
     if not args.all_neurons:
         plot_NN_vs_var_2D(
@@ -127,7 +127,6 @@ if args.mode.startswith("mean"):
             vary_index=vary_index,
             scorefun=get_single_neuron_function(
                 c.model, layer, neuron,
-                scaler=c.scaler,
                 input_transform=input_transform
             ),
             xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
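
The lambda above composes two existing methods of the project object c:
the input array is first passed through c.transform (presumably the
fitted scaler, given that the separate scaler argument is dropped in the
second hunk) and then through c.get_input_list, so the score function
receives model-ready inputs from a single callable. A rough sketch of
that composition with a stand-in object; FakeProject, its mean/std
scaling and the column split done by get_input_list are assumptions made
only for this illustration:

    import numpy as np

    class FakeProject:
        """Stand-in for the project object c; only mimics the two methods
        composed in the patch."""

        def __init__(self, mean, std, split_at):
            self.mean, self.std, self.split_at = mean, std, split_at

        def transform(self, x):
            # assumed behaviour: apply the fitted scaler to the raw inputs
            return (x - self.mean) / self.std

        def get_input_list(self, x):
            # assumed behaviour: split the 2D array into the list of arrays
            # a multi-input model expects
            return [x[:, :self.split_at], x[:, self.split_at:]]

    c = FakeProject(mean=0.0, std=2.0, split_at=1)

    if hasattr(c, "get_input_list"):
        # scale first, then split, so the score function sees model-ready inputs
        input_transform = lambda x: c.get_input_list(c.transform(x))
    else:
        input_transform = c.transform

    print(input_transform(np.array([[2.0, 4.0], [6.0, 8.0]])))
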
diff --git a/toolkit.py b/toolkit.py
index 8779db4..834d977 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1770,7 +1770,7 @@ class ClassificationProjectRNN(ClassificationProject):
 
     def _batch_transform(self, x, fn, batch_size):
         "Transform array in batches, temporarily setting mask_values to nan"
-        transformed = np.empty(len(x))
+        transformed = np.empty(x.shape, dtype=x.dtype)
         for start in range(0, len(x), batch_size):
             stop = start+batch_size
             x_batch = np.array(x[start:stop]) # copy
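
np.empty(len(x)) allocates a flat float64 buffer with one slot per event,
which cannot hold the per-event sequences an RNN input carries; matching
the input's shape and dtype lets the batch loop write its results back in
place. A simplified sketch of the batched transform (the nan-masking step
mentioned in the docstring is omitted, and the doubling fn is just a
placeholder):

    import numpy as np

    def batch_transform(x, fn, batch_size):
        # preallocate an output with the same shape/dtype as the input,
        # so per-batch results (which keep the trailing dimensions) fit
        transformed = np.empty(x.shape, dtype=x.dtype)
        for start in range(0, len(x), batch_size):
            stop = start + batch_size
            transformed[start:stop] = fn(np.array(x[start:stop]))  # copy batch, then transform
        return transformed

    # e.g. an (events, timesteps, features) array like an RNN input
    x = np.arange(24, dtype=np.float32).reshape(4, 3, 2)
    out = batch_transform(x, lambda b: b * 2.0, batch_size=2)
    print(out.shape)  # (4, 3, 2); np.empty(len(x)) would only hold 4 scalars
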
-- 
GitLab