diff --git a/toolkit.py b/toolkit.py index cb32d90c675ebd333999c09a1295c123a9480bde..620d3f6f5315036631cce6ba770bd2a461d1d4ac 100755 --- a/toolkit.py +++ b/toolkit.py @@ -734,8 +734,14 @@ class ClassificationProject(object): # one output node for binary classification output_layer = Dense(1, activation=self.activation_function_output)(hidden_layer) + outputs = [output_layer] - self._model = Model(inputs=[input_layer], outputs=[output_layer]) + # optional regression targets + for target_field in self.target_fields: + extra_target = Dense(1, activation="linear", name="target_{}".format(target_field))(hidden_layer) + outputs.append(extra_target) + + self._model = Model(inputs=[input_layer], outputs=outputs) self._compile_or_load_model() return self._model @@ -1861,7 +1867,14 @@ class ClassificationProjectRNN(ClassificationProject): if (dropout_fraction is not None) and (dropout_fraction > 0): combined = Dropout(rate=dropout_fraction)(combined) combined = Dense(1, activation=self.activation_function_output)(combined) - self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined) + outputs = [combined] + + # optional regression targets + for target_field in self.target_fields: + extra_target = Dense(1, activation="linear", name="target_{}".format(target_field))(combined) + outputs.append(extra_target) + + self._model = Model(inputs=rnn_inputs+[flat_input], outputs=outputs) self._compile_or_load_model() return self._model