diff --git a/toolkit.py b/toolkit.py index 57ef3f3c3b0d7466f4f76b176d5794b071909b8d..1fcf30cdc88aa3accd18da21947a5a0f03c97af3 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1361,6 +1361,7 @@ class ClassificationProjectRNN(ClassificationProject): def __init__(self, name, recurrent_field_names=None, + rnn_layer_nodes=32, mask_value=-999, **kwargs): """ @@ -1376,6 +1377,7 @@ class ClassificationProjectRNN(ClassificationProject): self.recurrent_field_names = recurrent_field_names if self.recurrent_field_names is None: self.recurrent_field_names = [] + self.rnn_layer_nodes = rnn_layer_nodes self.mask_value = mask_value # convert to of indices @@ -1425,7 +1427,7 @@ class ClassificationProjectRNN(ClassificationProject): for field_idx in self.recurrent_field_idx: chan_inp = Input(field_idx.shape[1:]) channel = Masking(mask_value=self.mask_value)(chan_inp) - channel = GRU(32)(channel) + channel = GRU(self.rnn_layer_nodes)(channel) # TODO: configure dropout for recurrent layers #channel = Dropout(0.3)(channel) rnn_inputs.append(chan_inp) @@ -1443,11 +1445,15 @@ class ClassificationProjectRNN(ClassificationProject): combined = Dense(1, activation=self.activation_function_output)(combined) self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined) self._compile_or_load_model() - else: - return self._model + return self._model def train(self, epochs=10): + self.load() + + for branch_index, branch in enumerate(self.fields): + self.plot_input(branch_index) + try: self.shuffle_training_data() # needed here too, in order to get correct validation data self.is_training = True