From c496862ff38b2904e07899b25273be5c82621f09 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 23 Aug 2018 09:40:04 +0200
Subject: [PATCH] Adding support for SimpleRNN

---
 toolkit.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 5b33e00..3498ab3 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -34,7 +34,7 @@ from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
 from sklearn.utils.extmath import stable_cumsum
 from keras.models import Sequential, Model, model_from_json
-from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate
+from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
@@ -716,7 +716,7 @@ class ClassificationProject(object):
 
         # plot model
         with open(os.path.join(self.project_dir, "model.svg"), "wb") as of:
-            of.write(model_to_dot(self._model).create("dot", format="svg"))
+            of.write(model_to_dot(self._model, show_shapes=True).create("dot", format="svg"))
 
 
     @property
@@ -1598,6 +1598,7 @@ class ClassificationProjectRNN(ClassificationProject):
                         recurrent_field_names=None,
                         rnn_layer_nodes=32,
                         mask_value=-999,
+                        recurrent_unit_type="GRU",
                         **kwargs):
         """
         recurrent_field_names example:
@@ -1616,6 +1617,7 @@ class ClassificationProjectRNN(ClassificationProject):
             self.recurrent_field_names = []
         self.rnn_layer_nodes = rnn_layer_nodes
         self.mask_value = mask_value
+        self.recurrent_unit_type = recurrent_unit_type
 
         # convert to  of indices
         self.recurrent_field_idx = []
@@ -1664,7 +1666,13 @@ class ClassificationProjectRNN(ClassificationProject):
             for field_idx in self.recurrent_field_idx:
                 chan_inp = Input(field_idx.shape[1:])
                 channel = Masking(mask_value=self.mask_value)(chan_inp)
-                channel = GRU(self.rnn_layer_nodes)(channel)
+                if self.recurrent_unit_type == "GRU":
+                    channel = GRU(self.rnn_layer_nodes)(channel)
+                elif self.recurrent_unit_type == "SimpleRNN":
+                    channel = SimpleRNN(self.rnn_layer_nodes)(channel)
+                else:
+                    raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type))
+                logger.info("Added {} unit".format(self.recurrent_unit_type))
                 # TODO: configure dropout for recurrent layers
                 #channel = Dropout(0.3)(channel)
                 rnn_inputs.append(chan_inp)
-- 
GitLab