Skip to content
Snippets Groups Projects
Commit 71e626f2 authored by Nikolai's avatar Nikolai
Browse files

output layer activation configurable

parent 4bd10c9b
No related branches found
No related tags found
No related merge requests found
...@@ -81,6 +81,8 @@ class ClassificationProject(object): ...@@ -81,6 +81,8 @@ class ClassificationProject(object):
:param activation_function: activation function in the hidden layers :param activation_function: activation function in the hidden layers
:param activation_function_output: activation function in the output layer
:param out_dir: base directory in which the project directories should be stored :param out_dir: base directory in which the project directories should be stored
:param scaler_type: sklearn scaler class name to transform the data before training (options: "StandardScaler", "RobustScaler") :param scaler_type: sklearn scaler class name to transform the data before training (options: "StandardScaler", "RobustScaler")
...@@ -136,6 +138,7 @@ class ClassificationProject(object): ...@@ -136,6 +138,7 @@ class ClassificationProject(object):
batch_size=128, batch_size=128,
validation_split=0.33, validation_split=0.33,
activation_function='relu', activation_function='relu',
activation_function_output='sigmoid',
project_dir=None, project_dir=None,
scaler_type="RobustScaler", scaler_type="RobustScaler",
step_signal=2, step_signal=2,
...@@ -158,6 +161,7 @@ class ClassificationProject(object): ...@@ -158,6 +161,7 @@ class ClassificationProject(object):
self.batch_size = batch_size self.batch_size = batch_size
self.validation_split = validation_split self.validation_split = validation_split
self.activation_function = activation_function self.activation_function = activation_function
self.activation_function_output = activation_function_output
self.scaler_type = scaler_type self.scaler_type = scaler_type
self.step_signal = step_signal self.step_signal = step_signal
self.step_bkg = step_bkg self.step_bkg = step_bkg
...@@ -440,7 +444,7 @@ class ClassificationProject(object): ...@@ -440,7 +444,7 @@ class ClassificationProject(object):
for layer_number in range(self.layers-1): for layer_number in range(self.layers-1):
self._model.add(Dense(self.nodes, activation=self.activation_function)) self._model.add(Dense(self.nodes, activation=self.activation_function))
# last layer is one neuron (binary classification) # last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid')) self._model.add(Dense(1, activation=self.activation_function_output))
logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts)) logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
Optimizer = getattr(keras.optimizers, self.optimizer) Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts) optimizer = Optimizer(**self.optimizer_opts)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment