diff --git a/toolkit.py b/toolkit.py index fb51f65d7c944fe2ee1c10f015f69012bbe71851..d673121cf9ad7fdb264a8e6cfdcf3c04438c563a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -791,14 +791,22 @@ class ClassificationProject(object): if (dropout_fraction is not None) and (dropout_fraction > 0): hidden_layer = Dropout(rate=dropout_fraction)(hidden_layer) - # one output node for binary classification - output_layer = Dense(1, activation=self.activation_function_output)(hidden_layer) - outputs = [output_layer] - # optional regression targets + extra_targets = [] for target_field in self.target_fields: extra_target = Dense(1, activation="linear", name="target_{}".format(target_field))(hidden_layer) - outputs.append(extra_target) + extra_targets.append(extra_target) + + if not self.target_fields: + # one output node for binary classification + output_layer = Dense(1, activation=self.activation_function_output)(hidden_layer) + outputs = [output_layer] + else: + # add another hidden layer on top of the regression targets and previous hidden layers + merge = concatenate([hidden_layer]+extra_targets) + hidden_layer2 = Dense(64, activation=self.activation_function)(merge) + output_class = Dense(1, activation=self.activation_function_output)(hidden_layer2) + outputs = [output_class]+extra_targets self._model = Model(inputs=[input_layer], outputs=outputs) self._compile_or_load_model()