Skip to content
Snippets Groups Projects
Commit 5119a79d authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

adjust max activation functions to use masking

parent 4f02d77d
No related branches found
No related tags found
No related merge requests found
...@@ -38,6 +38,8 @@ def create_random_event(ranges, mask_probs=None, mask_value=None): ...@@ -38,6 +38,8 @@ def create_random_event(ranges, mask_probs=None, mask_value=None):
random_event = random_event.reshape(-1, len(random_event)) random_event = random_event.reshape(-1, len(random_event))
# if given, mask values with a certain probability # if given, mask values with a certain probability
if mask_probs is not None: if mask_probs is not None:
if mask_value is None:
raise ValueError("Need to provide mask_value if random events should be masked")
for var_index, mask_prob in enumerate(mask_probs): for var_index, mask_prob in enumerate(mask_probs):
random_event[:,var_index][np.random.rand(len(random_event)) < mask_prob] = mask_value random_event[:,var_index][np.random.rand(len(random_event)) < mask_prob] = mask_value
return random_event return random_event
...@@ -122,7 +124,7 @@ def get_grad_function(model, layer, neuron): ...@@ -122,7 +124,7 @@ def get_grad_function(model, layer, neuron):
], ],
ignoreKwargs=["input_transform", "input_inverse_transform"], ignoreKwargs=["input_transform", "input_inverse_transform"],
) )
def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **kwargs): def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, mask_probs=None, mask_value=None, **kwargs):
gradient_function = get_grad_function(model, layer, neuron) gradient_function = get_grad_function(model, layer, neuron)
...@@ -132,9 +134,15 @@ def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **k ...@@ -132,9 +134,15 @@ def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **k
for i in range(ntries): for i in range(ntries):
if not (i%100): if not (i%100):
logger.info(i) logger.info(i)
res = max_activation_wrt_input(gradient_function, res = max_activation_wrt_input(
create_random_event(ranges), gradient_function,
**kwargs) create_random_event(
ranges,
mask_probs=mask_probs,
mask_value=mask_value
),
**kwargs
)
if res is not None: if res is not None:
loss, event = res loss, event = res
loss = np.array([loss]) loss = np.array([loss])
...@@ -195,7 +203,6 @@ class WeightedRobustScaler(RobustScaler): ...@@ -195,7 +203,6 @@ class WeightedRobustScaler(RobustScaler):
self.center_ = wqs[:,1] self.center_ = wqs[:,1]
self.scale_ = wqs[:,2]-wqs[:,0] self.scale_ = wqs[:,2]-wqs[:,0]
self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False) self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
print(self.scale_)
return self return self
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment