diff --git a/toolkit.py b/toolkit.py index 5b33e0052d145a22a6210d9c4e831f29c260de61..3498ab30d6a78a43f05ed2ecdf087e58c63d6be0 100755 --- a/toolkit.py +++ b/toolkit.py @@ -34,7 +34,7 @@ from sklearn.externals import joblib from sklearn.metrics import roc_curve, auc from sklearn.utils.extmath import stable_cumsum from keras.models import Sequential, Model, model_from_json -from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate +from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate, SimpleRNN from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard from keras.optimizers import SGD import keras.optimizers @@ -716,7 +716,7 @@ class ClassificationProject(object): # plot model with open(os.path.join(self.project_dir, "model.svg"), "wb") as of: - of.write(model_to_dot(self._model).create("dot", format="svg")) + of.write(model_to_dot(self._model, show_shapes=True).create("dot", format="svg")) @property @@ -1598,6 +1598,7 @@ class ClassificationProjectRNN(ClassificationProject): recurrent_field_names=None, rnn_layer_nodes=32, mask_value=-999, + recurrent_unit_type="GRU", **kwargs): """ recurrent_field_names example: @@ -1616,6 +1617,7 @@ class ClassificationProjectRNN(ClassificationProject): self.recurrent_field_names = [] self.rnn_layer_nodes = rnn_layer_nodes self.mask_value = mask_value + self.recurrent_unit_type = recurrent_unit_type # convert to of indices self.recurrent_field_idx = [] @@ -1664,7 +1666,13 @@ class ClassificationProjectRNN(ClassificationProject): for field_idx in self.recurrent_field_idx: chan_inp = Input(field_idx.shape[1:]) channel = Masking(mask_value=self.mask_value)(chan_inp) - channel = GRU(self.rnn_layer_nodes)(channel) + if self.recurrent_unit_type == "GRU": + channel = GRU(self.rnn_layer_nodes)(channel) + elif self.recurrent_unit_type == "SimpleRNN": + channel = SimpleRNN(self.rnn_layer_nodes)(channel) + else: + raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type)) + logger.info("Added {} unit".format(self.recurrent_unit_type)) # TODO: configure dropout for recurrent layers #channel = Dropout(0.3)(channel) rnn_inputs.append(chan_inp)