From 5119a79dbb47dd4b3517d8cd9309f52175de50c1 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Wed, 22 Aug 2018 11:40:14 +0200
Subject: [PATCH] adjust max activation functions to use masking

---
 utils.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/utils.py b/utils.py
index 029f76e..ea98793 100644
--- a/utils.py
+++ b/utils.py
@@ -38,6 +38,8 @@ def create_random_event(ranges, mask_probs=None, mask_value=None):
     random_event = random_event.reshape(-1, len(random_event))
     # if given, mask values with a certain probability
     if mask_probs is not None:
+        if mask_value is None:
+            raise ValueError("Need to provide mask_value if random events should be masked")
         for var_index, mask_prob in enumerate(mask_probs):
             random_event[:,var_index][np.random.rand(len(random_event)) < mask_prob] = mask_value
     return random_event
@@ -122,7 +124,7 @@ def get_grad_function(model, layer, neuron):
        ],
        ignoreKwargs=["input_transform", "input_inverse_transform"],
 )
-def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **kwargs):
+def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, mask_probs=None, mask_value=None, **kwargs):
 
     gradient_function = get_grad_function(model, layer, neuron)
 
@@ -132,9 +134,15 @@ def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **k
     for i in range(ntries):
         if not (i%100):
             logger.info(i)
-        res = max_activation_wrt_input(gradient_function,
-                                       create_random_event(ranges),
-                                       **kwargs)
+        res = max_activation_wrt_input(
+            gradient_function,
+            create_random_event(
+                ranges,
+                mask_probs=mask_probs,
+                mask_value=mask_value
+            ),
+            **kwargs
+        )
         if res is not None:
             loss, event = res
             loss = np.array([loss])
@@ -195,7 +203,6 @@ class WeightedRobustScaler(RobustScaler):
             self.center_ = wqs[:,1]
             self.scale_ = wqs[:,2]-wqs[:,0]
             self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
-            print(self.scale_)
             return self
 
 
-- 
GitLab