From ff7f130194b567aa7d8b7ccdb1b06db1fb1df483 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 4 Sep 2018 09:44:05 +0200
Subject: [PATCH] adding lstm option for RNN

---
 toolkit.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/toolkit.py b/toolkit.py
index d124ef9..f108d1d 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -34,7 +34,7 @@ from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
 from sklearn.utils.extmath import stable_cumsum
 from keras.models import Sequential, Model, model_from_json
-from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN
+from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
@@ -1736,6 +1736,8 @@ class ClassificationProjectRNN(ClassificationProject):
                     channel = GRU(self.rnn_layer_nodes)(channel)
                 elif self.recurrent_unit_type == "SimpleRNN":
                     channel = SimpleRNN(self.rnn_layer_nodes)(channel)
+                elif self.recurrent_unit_type == "LSTM":
+                    channel = LSTM(self.rnn_layer_nodes)(channel)
                 else:
                     raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type))
                 logger.info("Added {} unit".format(self.recurrent_unit_type))
-- 
GitLab