From 71e626f26e2e54cfb327cbe3168d2389bf95c52a Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Wed, 9 May 2018 09:28:33 +0200
Subject: [PATCH] make output layer activation configurable

---
 toolkit.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)
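
Note: the new option defaults to 'sigmoid', so existing projects keep the
previously hard-coded behaviour. A minimal usage sketch, assuming the
constructor takes a project name as its first positional argument (only the
keyword arguments visible in the hunks below are taken from the actual code):

    from toolkit import ClassificationProject

    # hypothetical instantiation; the positional "my_project" name is an
    # assumption, the keyword arguments appear in the diff below
    project = ClassificationProject(
        "my_project",
        activation_function='relu',           # hidden layers (existing option)
        activation_function_output='linear',  # new: raw scores instead of sigmoid
    )

A non-default output activation such as 'linear' only makes sense if the loss
is chosen accordingly; for a standard binary cross-entropy setup the
'sigmoid' default should be kept.
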

diff --git a/toolkit.py b/toolkit.py
index 191991f..473a617 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -81,6 +81,8 @@ class ClassificationProject(object):
 
     :param activation_function: activation function in the hidden layers
 
+    :param activation_function_output: activation function in the output layer
+
     :param out_dir: base directory in which the project directories should be stored
 
     :param scaler_type: sklearn scaler class name to transform the data before training (options: "StandardScaler", "RobustScaler")
@@ -136,6 +138,7 @@ class ClassificationProject(object):
                         batch_size=128,
                         validation_split=0.33,
                         activation_function='relu',
+                        activation_function_output='sigmoid',
                         project_dir=None,
                         scaler_type="RobustScaler",
                         step_signal=2,
@@ -158,6 +161,7 @@ class ClassificationProject(object):
         self.batch_size = batch_size
         self.validation_split = validation_split
         self.activation_function = activation_function
+        self.activation_function_output = activation_function_output
         self.scaler_type = scaler_type
         self.step_signal = step_signal
         self.step_bkg = step_bkg
@@ -440,7 +444,7 @@ class ClassificationProject(object):
             for layer_number in range(self.layers-1):
                 self._model.add(Dense(self.nodes, activation=self.activation_function))
             # last layer is one neuron (binary classification)
-            self._model.add(Dense(1, activation='sigmoid'))
+            self._model.add(Dense(1, activation=self.activation_function_output))
             logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
             Optimizer = getattr(keras.optimizers, self.optimizer)
             optimizer = Optimizer(**self.optimizer_opts)
-- 
GitLab
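
For context, a standalone sketch of the network construction this patch
touches (not the project's code; everything outside the visible hunk, i.e.
the first layer, the optimizer, and the loss, is an assumption based on the
context lines):

    from keras.models import Sequential
    from keras.layers import Dense

    def build_model(n_inputs, nodes=64, layers=3,
                    activation_function='relu',
                    activation_function_output='sigmoid'):
        model = Sequential()
        # first hidden layer, then layers-1 further hidden layers,
        # mirroring the range(self.layers-1) loop in the hunk
        model.add(Dense(nodes, input_dim=n_inputs, activation=activation_function))
        for _ in range(layers - 1):
            model.add(Dense(nodes, activation=activation_function))
        # single output neuron; its activation is now configurable
        # (default 'sigmoid' preserves the previous behaviour)
        model.add(Dense(1, activation=activation_function_output))
        model.compile(optimizer='adam', loss='binary_crossentropy')
        return model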