diff --git a/toolkit.py b/toolkit.py
index 57ef3f3c3b0d7466f4f76b176d5794b071909b8d..1fcf30cdc88aa3accd18da21947a5a0f03c97af3 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1361,6 +1361,7 @@ class ClassificationProjectRNN(ClassificationProject):
 
     def __init__(self, name,
                  recurrent_field_names=None,
+                 rnn_layer_nodes=32,
                  mask_value=-999,
                  **kwargs):
         """
@@ -1376,6 +1377,7 @@ class ClassificationProjectRNN(ClassificationProject):
         self.recurrent_field_names = recurrent_field_names
         if self.recurrent_field_names is None:
             self.recurrent_field_names = []
+        self.rnn_layer_nodes = rnn_layer_nodes
         self.mask_value = mask_value
 
         # convert to list of indices
@@ -1425,7 +1427,7 @@ class ClassificationProjectRNN(ClassificationProject):
             for field_idx in self.recurrent_field_idx:
                 chan_inp = Input(field_idx.shape[1:])
                 channel = Masking(mask_value=self.mask_value)(chan_inp)
-                channel = GRU(32)(channel)
+                channel = GRU(self.rnn_layer_nodes)(channel)
                 # TODO: configure dropout for recurrent layers
                 #channel = Dropout(0.3)(channel)
                 rnn_inputs.append(chan_inp)
@@ -1443,11 +1445,15 @@ class ClassificationProjectRNN(ClassificationProject):
             combined = Dense(1, activation=self.activation_function_output)(combined)
             self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
             self._compile_or_load_model()
-        else:
-            return self._model
+        return self._model
 
 
     def train(self, epochs=10):
+        self.load()
+
+        for branch_index, branch in enumerate(self.fields):
+            self.plot_input(branch_index)
+
         try:
             self.shuffle_training_data() # needed here too, in order to get correct validation data
             self.is_training = True