From c496862ff38b2904e07899b25273be5c82621f09 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 23 Aug 2018 09:40:04 +0200
Subject: [PATCH] Adding support for SimpleRNN

---
 toolkit.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 5b33e00..3498ab3 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -34,7 +34,7 @@ from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
 from sklearn.utils.extmath import stable_cumsum
 from keras.models import Sequential, Model, model_from_json
-from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate
+from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
@@ -716,7 +716,7 @@ class ClassificationProject(object):

         # plot model
         with open(os.path.join(self.project_dir, "model.svg"), "wb") as of:
-            of.write(model_to_dot(self._model).create("dot", format="svg"))
+            of.write(model_to_dot(self._model, show_shapes=True).create("dot", format="svg"))


     @property
@@ -1598,6 +1598,7 @@ class ClassificationProjectRNN(ClassificationProject):
                  recurrent_field_names=None,
                  rnn_layer_nodes=32,
                  mask_value=-999,
+                 recurrent_unit_type="GRU",
                  **kwargs):
         """
         recurrent_field_names example:
@@ -1616,6 +1617,7 @@ class ClassificationProjectRNN(ClassificationProject):
             self.recurrent_field_names = []
         self.rnn_layer_nodes = rnn_layer_nodes
         self.mask_value = mask_value
+        self.recurrent_unit_type = recurrent_unit_type

         # convert to of indices
         self.recurrent_field_idx = []
@@ -1664,7 +1666,13 @@ class ClassificationProjectRNN(ClassificationProject):
         for field_idx in self.recurrent_field_idx:
             chan_inp = Input(field_idx.shape[1:])
             channel = Masking(mask_value=self.mask_value)(chan_inp)
-            channel = GRU(self.rnn_layer_nodes)(channel)
+            if self.recurrent_unit_type == "GRU":
+                channel = GRU(self.rnn_layer_nodes)(channel)
+            elif self.recurrent_unit_type == "SimpleRNN":
+                channel = SimpleRNN(self.rnn_layer_nodes)(channel)
+            else:
+                raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type))
+            logger.info("Added {} unit".format(self.recurrent_unit_type))
             # TODO: configure dropout for recurrent layers
             #channel = Dropout(0.3)(channel)
             rnn_inputs.append(chan_inp)
--
GitLab
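
For illustration (not part of the patch itself): a minimal sketch of how the new
recurrent_unit_type option would be used. Only recurrent_field_names,
rnn_layer_nodes, mask_value, and recurrent_unit_type appear in the diff above;
the project name and field names below are hypothetical placeholders.

    from toolkit import ClassificationProjectRNN

    # Default behaviour is unchanged (recurrent_unit_type="GRU"); passing
    # "SimpleRNN" selects the new unit, and any other value raises
    # NotImplementedError in _transform_data's model-building branch above.
    project = ClassificationProjectRNN(
        "my_rnn_project",                              # hypothetical project name
        recurrent_field_names=[["jet1Pt", "jet2Pt"]],  # hypothetical field names
        rnn_layer_nodes=32,
        mask_value=-999,
        recurrent_unit_type="SimpleRNN",
    )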