Commit 87b25c46 authored by Nikolai.Hartmann

allow different number of nodes per layer

parent 96a77543
@@ -106,7 +106,7 @@ class ClassificationProject(object):
     :param layers: number of layers in the neural network
-    :param nodes: number of nodes in each layer
+    :param nodes: list of the number of nodes in each layer. If only a single number is given, use this number for every layer
     :param dropout: dropout fraction after each hidden layer. Set to None for no Dropout
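
With this change, nodes accepts either a single integer or a list. A hypothetical pair of calls (all other constructor arguments omitted; the call shape is illustrative, not copied from the project):

    # Hypothetical calls; only the keyword arguments touched by this
    # commit are shown:
    # ClassificationProject(..., layers=3, nodes=64)           # -> 64, 64, 64
    # ClassificationProject(..., layers=3, nodes=[64, 32, 16]) # per-layer widths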
@@ -230,6 +230,12 @@ class ClassificationProject(object):
         self.identifiers = identifiers
         self.layers = layers
         self.nodes = nodes
+        if not isinstance(self.nodes, list):
+            self.nodes = [self.nodes for i in range(self.layers)]
+        if len(self.nodes) != self.layers:
+            self.layers = len(self.nodes)
+            logger.warning("Number of layers not equal to the given nodes "
+                           "per layer - adjusted to " + str(self.layers))
         self.dropout = dropout
         self.batch_size = batch_size
         self.validation_split = validation_split
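
The normalisation above can be read in isolation as follows. This is a minimal sketch with a hypothetical helper name, not code from the project: a scalar is repeated for every layer, and on a length mismatch the explicit list wins.

    # Hypothetical standalone version of the scalar-to-list normalisation:
    def normalize_nodes(nodes, layers):
        if not isinstance(nodes, list):
            nodes = [nodes] * layers      # repeat a single width for every layer
        if len(nodes) != layers:
            layers = len(nodes)           # the explicit list takes precedence
        return nodes, layers

    print(normalize_nodes(64, 3))         # ([64, 64, 64], 3)
    print(normalize_nodes([64, 32], 3))   # ([64, 32], 2)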
@@ -551,10 +557,10 @@ class ClassificationProject(object):
         self._model = Sequential()
         # first hidden layer
-        self._model.add(Dense(self.nodes, input_dim=len(self.branches), activation=self.activation_function))
+        self._model.add(Dense(self.nodes[0], input_dim=len(self.branches), activation=self.activation_function))
         # the other hidden layers
-        for layer_number in range(self.layers-1):
-            self._model.add(Dense(self.nodes, activation=self.activation_function))
+        for node_count, layer_number in zip(self.nodes[1:], range(self.layers-1)):
+            self._model.add(Dense(node_count, activation=self.activation_function))
             if self.dropout is not None:
                 self._model.add(Dropout(rate=self.dropout))
         # last layer is one neuron (binary classification)
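
Putting the pieces together, the resulting architecture can be sketched standalone. This assumes standalone Keras; the input dimension, layer widths, activations, and dropout value are example stand-ins for the project's own attributes (len(self.branches), self.nodes, self.activation_function, self.dropout), and the sigmoid output activation is an assumption, not confirmed by the diff.

    from keras.models import Sequential
    from keras.layers import Dense, Dropout

    n_inputs = 10          # stands in for len(self.branches)
    nodes = [64, 32, 16]   # one width per hidden layer
    dropout = 0.3          # or None to skip the Dropout layers

    model = Sequential()
    # first hidden layer fixes the input dimension
    model.add(Dense(nodes[0], input_dim=n_inputs, activation="relu"))
    # the remaining hidden layers take their widths from the list
    for node_count in nodes[1:]:
        model.add(Dense(node_count, activation="relu"))
        if dropout is not None:
            model.add(Dropout(rate=dropout))
    # last layer is one neuron (binary classification); sigmoid assumed
    model.add(Dense(1, activation="sigmoid"))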