diff --git a/toolkit.py b/toolkit.py
index b31a15a273293af6c5cc516ea33abe400b0524e1..c941dd0decfe64feb0093be5178e0fb0761642e1 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -19,6 +19,7 @@ import math
 import glob
 import shutil
 import gc
+import random
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -33,7 +34,7 @@ from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
 from sklearn.utils.extmath import stable_cumsum
 from keras.models import Sequential, Model, model_from_json
-from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate
+from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
@@ -744,7 +745,7 @@ class ClassificationProject(object):
 
         # plot model
         with open(os.path.join(self.project_dir, "model.svg"), "wb") as of:
-            of.write(model_to_dot(self._model).create("dot", format="svg"))
+            of.write(model_to_dot(self._model, show_shapes=True).create("dot", format="svg"))
 
 
     @property
@@ -1626,6 +1627,7 @@ class ClassificationProjectRNN(ClassificationProject):
                         recurrent_field_names=None,
                         rnn_layer_nodes=32,
                         mask_value=-999,
+                        recurrent_unit_type="GRU",
                         **kwargs):
         """
         recurrent_field_names example:
@@ -1644,6 +1646,7 @@ class ClassificationProjectRNN(ClassificationProject):
             self.recurrent_field_names = []
         self.rnn_layer_nodes = rnn_layer_nodes
         self.mask_value = mask_value
+        self.recurrent_unit_type = recurrent_unit_type
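+        # supported values for recurrent_unit_type: "GRU" (default) and "SimpleRNN"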
 
         # convert to list of indices
         self.recurrent_field_idx = []
@@ -1684,7 +1687,13 @@ class ClassificationProjectRNN(ClassificationProject):
             for field_idx in self.recurrent_field_idx:
                 chan_inp = Input(field_idx.shape[1:])
                 channel = Masking(mask_value=self.mask_value)(chan_inp)
-                channel = GRU(self.rnn_layer_nodes)(channel)
+                if self.recurrent_unit_type == "GRU":
+                    channel = GRU(self.rnn_layer_nodes)(channel)
+                elif self.recurrent_unit_type == "SimpleRNN":
+                    channel = SimpleRNN(self.rnn_layer_nodes)(channel)
+                else:
+                    raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type))
+                logger.info("Added {} unit".format(self.recurrent_unit_type))
                 # TODO: configure dropout for recurrent layers
                 #channel = Dropout(0.3)(channel)
                 rnn_inputs.append(chan_inp)
@@ -1731,6 +1740,37 @@ class ClassificationProjectRNN(ClassificationProject):
         self.checkpoint_model()
 
 
+    def clean_mask(self, x):
+        """
+        Mask recurrent fields such that once a masked value occurs,
+        all values corresponding to the same and following objects are
+        masked as well. Works in place.
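+
+        For example, with mask_value=-999 and three objects per recurrent
+        field, an event whose second object contains -999 in any of its
+        variables has its second and third objects set entirely to -999.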
+        """
+        for recurrent_field_idx in self.recurrent_field_idx:
+            for evt in x:
+                masked = False
+                for line_idx in recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:]):
+                    if (evt[line_idx] == self.mask_value).any():
+                        masked = True
+                    if masked:
+                        evt[line_idx] = self.mask_value
+
+
+    def mask_uniform(self, x):
+        """
+        Mask recurrent fields with a random (uniform) number of objects. Works in place.
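+        The number of unmasked objects per event is drawn uniformly between
+        0 and the total number of objects in the recurrent field; all
+        objects after that are set to the mask value.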
+        """
+        for recurrent_field_idx in self.recurrent_field_idx:
+            for evt in x:
+                masked = False
+                nobj = int(random.random()*(recurrent_field_idx.shape[1]+1))
+                for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])):
+                    if obj_number == nobj:
+                        masked = True
+                    if masked:
+                        evt[line_idx] = self.mask_value
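+
+    # Possible usage (sketch): apply clean_mask or mask_uniform to the flat
+    # input array before it is split up by get_input_list, e.g. to train
+    # with a varying number of masked objects per event.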
+
+
     def get_input_list(self, x):
         "Format the input starting from flat ntuple"
         x_input = []
diff --git a/utils.py b/utils.py
index f3323c42a18772bd3b4b68004af752ef7610b157..48f58f60b909b912d8b47038fd732c1ce339a05c 100644
--- a/utils.py
+++ b/utils.py
@@ -33,22 +33,31 @@ def get_single_neuron_function(model, layer, neuron, input_transform=None):
     return eval_single_neuron
 
 
-def create_random_event(ranges):
+def create_random_event(ranges, mask_probs=None, mask_value=None):
     random_event = np.array([p[0]+(p[1]-p[0])*np.random.rand() for p in ranges])
     random_event = random_event.reshape(-1, len(random_event))
+    # if given, mask values with a certain probability
+    if mask_probs is not None:
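+        # mask_probs is expected to contain one probability per variable,
+        # in the same order as ranges (see get_ranges)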
+        if mask_value is None:
+            raise ValueError("Need to provide mask_value if random events should be masked")
+        for var_index, mask_prob in enumerate(mask_probs):
+            random_event[:,var_index][np.random.rand(len(random_event)) < mask_prob] = mask_value
     return random_event
 
 
 def get_ranges(x, quantiles, weights, mask_value=None, filter_index=None):
     "Get ranges for plotting or random event generation based on quantiles"
     ranges = []
+    mask_probs = []
     for var_index in range(x.shape[1]):
         if (filter_index is not None) and (var_index != filter_index):
             continue
         x_var = x[:,var_index]
         not_masked = np.where(x_var != mask_value)[0]
+        masked = np.where(x_var == mask_value)[0]
         ranges.append(weighted_quantile(x_var[not_masked], quantiles, sample_weight=weights[not_masked]))
-    return ranges
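+        # fraction of entries at mask_value, usable as a per-variable
+        # masking probability for create_random_event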
+        mask_probs.append(float(len(masked))/float(len(x_var)))
+    return ranges, mask_probs
 
 
 def max_activation_wrt_input(gradient_function, random_event, threshold=None, maxthreshold=None, maxit=100, step=1, const_indices=[],
@@ -115,7 +124,7 @@ def get_grad_function(model, layer, neuron):
        ],
        ignoreKwargs=["input_transform", "input_inverse_transform"],
 )
-def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **kwargs):
+def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, mask_probs=None, mask_value=None, **kwargs):
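+    # mask_probs and mask_value are forwarded to create_random_event; a usage
+    # sketch: pass the ranges and mask_probs returned by get_ranges together
+    # with the same mask_value, so the random seed events reproduce the
+    # masking pattern of the data.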
 
     gradient_function = get_grad_function(model, layer, neuron)
 
@@ -125,9 +134,15 @@ def get_max_activation_events(model, ranges, ntries, layer, neuron, seed=42, **k
     for i in range(ntries):
         if not (i%100):
             logger.info(i)
-        res = max_activation_wrt_input(gradient_function,
-                                       create_random_event(ranges),
-                                       **kwargs)
+        res = max_activation_wrt_input(
+            gradient_function,
+            create_random_event(
+                ranges,
+                mask_probs=mask_probs,
+                mask_value=mask_value
+            ),
+            **kwargs
+        )
         if res is not None:
             loss, event = res
             loss = np.array([loss])
@@ -188,7 +203,6 @@ class WeightedRobustScaler(RobustScaler):
             self.center_ = wqs[:,1]
             self.scale_ = wqs[:,2]-wqs[:,0]
             self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
-            print(self.scale_)
             return self
 
 
@@ -202,6 +216,15 @@ class WeightedRobustScaler(RobustScaler):
             return super(WeightedRobustScaler, self).transform(X)
 
 
+    def inverse_transform(self, X):
+        if np.isnan(X).any():
+            # invert the scaling manually when X contains NaNs (e.g. masked
+            # entries), which the parent implementation does not accept;
+            # operate on a copy so the caller's array is not modified in place
+            return X * self.scale_ + self.center_
+        else:
+            return super(WeightedRobustScaler, self).inverse_transform(X)
+
+
 def poisson_asimov_significance(s, ds, b, db):
     "see `<http://www.pp.rhul.ac.uk/~cowan/stat/medsig/medsigNote.pdf>`_)"
     db = np.sqrt(db**2+ds**2)