Skip to content
Snippets Groups Projects
Commit ff7f1301 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

adding lstm option for RNN

parent 49243434
No related branches found
No related tags found
No related merge requests found
......@@ -34,7 +34,7 @@ from sklearn.externals import joblib
from sklearn.metrics import roc_curve, auc
from sklearn.utils.extmath import stable_cumsum
from keras.models import Sequential, Model, model_from_json
from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN
from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
from keras.optimizers import SGD
import keras.optimizers
......@@ -1736,6 +1736,8 @@ class ClassificationProjectRNN(ClassificationProject):
channel = GRU(self.rnn_layer_nodes)(channel)
elif self.recurrent_unit_type == "SimpleRNN":
channel = SimpleRNN(self.rnn_layer_nodes)(channel)
elif self.recurrent_unit_type == "LSTM":
channel = LSTM(self.rnn_layer_nodes)(channel)
else:
raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type))
logger.info("Added {} unit".format(self.recurrent_unit_type))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment