diff --git a/toolkit.py b/toolkit.py
index 09d7bde174aa35681b849786dd64e1fe0623cd60..37c77ed072e73b9229e523798d55c82e41ddfe25 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-__all__ = ["ClassificationProject", "ClassificationProjectDataFrame"]
+__all__ = ["ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"]
 
 from sys import version_info
 
@@ -578,9 +578,12 @@ class ClassificationProject(object):
     def _transform_data(self):
         if not self.data_transformed:
             # todo: what to do about the outliers? Where do they come from?
-            logger.debug("training data before transformation: {}".format(self.x_train))
-            logger.debug("minimum values: {}".format([np.min(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
-            logger.debug("maximum values: {}".format([np.max(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
+            if logger.level <= logging.DEBUG:
+                logger.debug("training data before transformation: {}".format(self.x_train))
+                logger.debug("minimum values: {}".format([np.min(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
+                                                          for i in range(self.x_train.shape[1])]))
+                logger.debug("maximum values: {}".format([np.max(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
+                                                          for i in range(self.x_train.shape[1])]))
             orig_copy_setting = self.scaler.copy
             self.scaler.copy = False
             self.x_train = self.scaler.transform(self.x_train)
@@ -1353,32 +1356,60 @@ class ClassificationProjectRNN(ClassificationProject):
     A little wrapper to use recurrent units for things like jet collections
     """
 
-    def __init__(self,
-                 recurrent_fields=None,
+    def __init__(self, name,
+                 recurrent_field_names=None,
                  mask_value=-999,
                  **kwargs):
         """
-        recurrent_fields example:
+        recurrent_field_names example:
         [["jet1Pt", "jet1Eta", "jet1Phi"],
          ["jet2Pt", "jet2Eta", "jet2Phi"],
          ["jet3Pt", "jet3Eta", "jet3Phi"]],
         [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"],
          ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]],
         """
-        self.recurrent_fields = recurrent_fields
-        if self.recurrent_fields is None:
-            self.recurrent_fields = []
-        for i, recurrent_field in enumerate(self.recurrent_fields):
-            self.recurrent_fields[i] = np.array(recurrent_field)
-            if self.recurrent_fields[i].dtype == np.object:
+        super(ClassificationProjectRNN, self).__init__(name, **kwargs)
+
+        self.recurrent_field_names = recurrent_field_names
+        if self.recurrent_field_names is None:
+            self.recurrent_field_names = []
+        self.mask_value = mask_value
+
+        # convert to list of indices
+        self.recurrent_field_idx = []
+        for field_name_list in self.recurrent_field_names:
+            field_names = np.array([field_name_list])
+            if field_names.dtype == np.object:
                 raise ValueError(
                     "Invalid entry for recurrent fields: {} - "
                     "please ensure that the length for all elements in the list is equal"
-                    .format(recurrent_field)
+                    .format(field_names)
                 )
-        self.mask_value = mask_value
-        super(ClassificationProjectRNN, self).__init__()
-        self.flat_fields = [field for field in self.fields if not field in self.recurrent_fields]
+            field_idx = (
+                np.array([self.fields.index(field_name)
+                          for field_name in field_names.reshape(-1)])
+                .reshape(field_names.shape)
+            )
+            self.recurrent_field_idx.append(field_idx)
+        self.flat_fields = []
+        for field in self.fields:
+            if any(self.fields.index(field) in field_idx.reshape(-1) for field_idx in self.recurrent_field_idx):
+                continue
+            self.flat_fields.append(field)
+
+        if self.scaler_type != "WeightedRobustScaler":
+            raise NotImplementedError(
+                "Invalid scaler '{}' - only WeightedRobustScaler is currently supported for RNN"
+                .format(self.scaler_type)
+            )
+
+
+    def _transform_data(self):
+        self.x_train[self.x_train == self.mask_value] = np.nan
+        self.x_test[self.x_test == self.mask_value] = np.nan
+        super(ClassificationProjectRNN, self)._transform_data()
+        self.x_train[np.isnan(self.x_train)] = self.mask_value
+        self.x_test[np.isnan(self.x_test)] = self.mask_value
 
 
     @property
@@ -1386,32 +1417,29 @@ class ClassificationProjectRNN(ClassificationProject):
         pass
 
 
+    def get_input_list(self, x):
+        "Format the input starting from flat ntuple"
+        x_input = []
+        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
+        x_input.append(x_flat)
+        for field_idx in self.recurrent_field_idx:
+            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape)
+            x_input.append(x_recurrent)
+        return x_input
+
+
     def yield_batch(self):
         x_train, y_train, w_train = self.training_data
         while True:
             shuffled_idx = np.random.permutation(len(x_train))
             for start in range(0, len(shuffled_idx), int(self.batch_size)):
                 x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
-                x_flat = x_batch[:,self.flat_fields]
-                x_input = []
-                x_input.append(x_flat)
-                for recurrent_field in self.recurrent_fields:
-                    x_recurrent = x_batch[:,recurrent_field.reshape(-1)].reshape(-1, *recurrent_field.shape)
-                    x_input.append(x_recurrent)
+                y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
+                w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
+                x_input = self.get_input_list(x_batch)
                 yield (x_input,
                        y_train[shuffled_idx[start:start+int(self.batch_size)]],
-                       w_train[shuffled_idx[start:start+int(self.batch_size)]]*self.balanced_class_weight[class_label])
-            # # shuffle the entries for this class label
-            # rn_state = np.random.get_state()
-            # x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
-            # np.random.set_state(rn_state)
-            # w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
-            # # yield them batch wise
-            # for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
-            #     yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
-            #            y_train[y_train==class_label][start:start+int(self.batch_size/2)],
-            #            w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label])
-            # restart
+                       w_batch*np.array(self.class_weight)[y_batch.astype(int)])
 
 
 if __name__ == "__main__":
diff --git a/utils.py b/utils.py
index 1da306ca79f3b350d9e12708c942254ddebea067..5f6145ffb8a59008551c1945dfeaa717b687292a 100644
--- a/utils.py
+++ b/utils.py
@@ -134,13 +134,25 @@ def weighted_quantile(values, quantiles, sample_weight=None, values_sorted=False
 class WeightedRobustScaler(RobustScaler):
 
     def fit(self, X, y=None, weights=None):
-        RobustScaler.fit(self, X, y)
+        if not np.isnan(X).any():
+            # these checks don't work for nan values
+            super(WeightedRobustScaler, self).fit(X, y)
         if weights is None:
             return self
         else:
-            wqs = np.array([weighted_quantile(X[:,i], [0.25, 0.5, 0.75], sample_weight=weights) for i in range(X.shape[1])])
+            wqs = np.array([weighted_quantile(X[:,i][~np.isnan(X[:,i])], [0.25, 0.5, 0.75], sample_weight=weights) for i in range(X.shape[1])])
             self.center_ = wqs[:,1]
             self.scale_ = wqs[:,2]-wqs[:,0]
             self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False)
+            print(self.scale_)
             return self
 
+
+    def transform(self, X):
+        if np.isnan(X).any():
+            # we'd like to ignore nan values, so lets calculate without further checks
+            X -= self.center_
+            X /= self.scale_
+            return X
+        else:
+            return super(WeightedRobustScaler, self).transform(X)