From ff7f130194b567aa7d8b7ccdb1b06db1fb1df483 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 4 Sep 2018 09:44:05 +0200 Subject: [PATCH] adding lstm option for RNN --- toolkit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/toolkit.py b/toolkit.py index d124ef9..f108d1d 100755 --- a/toolkit.py +++ b/toolkit.py @@ -34,7 +34,7 @@ from sklearn.externals import joblib from sklearn.metrics import roc_curve, auc from sklearn.utils.extmath import stable_cumsum from keras.models import Sequential, Model, model_from_json -from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN +from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard from keras.optimizers import SGD import keras.optimizers @@ -1736,6 +1736,8 @@ class ClassificationProjectRNN(ClassificationProject): channel = GRU(self.rnn_layer_nodes)(channel) elif self.recurrent_unit_type == "SimpleRNN": channel = SimpleRNN(self.rnn_layer_nodes)(channel) + elif self.recurrent_unit_type == "LSTM": + channel = LSTM(self.rnn_layer_nodes)(channel) else: raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type)) logger.info("Added {} unit".format(self.recurrent_unit_type)) -- GitLab