Commit b9e191be authored by Nikolai.Hartmann

Extended dropout support
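
The dropout option now accepts either a single fraction or a list with one
fraction per hidden layer, and the new dropout_input option adds a Dropout
layer directly on the inputs. A hypothetical usage sketch (the project name
and argument values below are illustrative only, not part of this commit):

    project = ClassificationProject(
        "my_project",              # hypothetical project name
        layers=3,
        nodes=64,
        dropout=[0.2, 0.2, 0.5],   # one dropout fraction per hidden layer
        dropout_input=0.1,         # dropout applied to the input layer
    )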

parent 5ce28def
@@ -110,7 +110,9 @@ class ClassificationProject(object):
     :param nodes: list number of nodes in each layer. If only a single number is given, use this number for every layer
-    :param dropout: dropout fraction after each hidden layer. Set to None for no Dropout
+    :param dropout: dropout fraction after each hidden layer. You can also pass a list for dropout fractions for each layer. Set to None for no Dropout.
+    :param dropout_input: dropout fraction for the input layer. Set to None for no Dropout.
     :param batch_size: size of the training batches
@@ -196,6 +198,7 @@ class ClassificationProject(object):
                  layers=3,
                  nodes=64,
                  dropout=None,
+                 dropout_input=None,
                  batch_size=128,
                  validation_split=0.33,
                  activation_function='relu',
@@ -243,6 +246,11 @@ class ClassificationProject(object):
             logger.warning("Number of layers not equal to the given nodes "
                            "per layer - adjusted to " + str(self.layers))
         self.dropout = dropout
+        if not isinstance(self.dropout, list):
+            self.dropout = [self.dropout for i in range(self.layers)]
+        if len(self.dropout) != self.layers:
+            raise ValueError("List of dropout fractions has to be of equal size as the number of layers!")
+        self.dropout_input = dropout_input
         self.batch_size = batch_size
         self.validation_split = validation_split
         self.activation_function = activation_function
@@ -588,15 +596,21 @@ class ClassificationProject(object):
         self._model = Sequential()
-        # first hidden layer
-        self._model.add(Dense(self.nodes[0], input_dim=len(self.fields), activation=self.activation_function))
-        # the other hidden layers
-        for node_count, layer_number in zip(self.nodes[1:], range(self.layers-1)):
+        if self.dropout_input is None:
+            self._model.add(Dense(self.nodes[0], input_dim=len(self.fields), activation=self.activation_function))
+            # in case of no Dropout we already have the first hidden layer
+            start_layer = 1
+        else:
+            self._model.add(Dropout(rate=self.dropout_input, input_shape=(len(self.fields),)))
+            start_layer = 0
+        # the (other) hidden layers
+        for node_count, dropout_fraction in zip(self.nodes[start_layer:], self.dropout[start_layer:]):
             self._model.add(Dense(node_count, activation=self.activation_function))
-            if self.dropout is not None:
-                self._model.add(Dropout(rate=self.dropout))
+            if dropout_fraction > 0:
+                self._model.add(Dropout(rate=dropout_fraction))
         # last layer is one neuron (binary classification)
         self._model.add(Dense(1, activation=self.activation_function_output))
         logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
         Optimizer = getattr(keras.optimizers, self.optimizer)
         optimizer = Optimizer(**self.optimizer_opts)
@@ -607,9 +621,11 @@ class ClassificationProject(object):
                             loss=self.loss,
                             metrics=['accuracy'])
         np.random.set_state(rn_state)
         if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
             if self.is_training:
-                continue_training = self.query_yn("Found previously trained weights - continue training (choosing N will restart)? (Y/N) ")
+                continue_training = self.query_yn("Found previously trained weights - "
+                                                  "continue training (choosing N will restart)? (Y/N) ")
             else:
                 continue_training = True
             if continue_training:
...
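
For reference, a minimal self-contained sketch of the new model-building
logic, assuming standalone Keras 2. Here build_model, n_fields, and the
literal values in the call at the bottom are hypothetical stand-ins for the
ClassificationProject attributes used in the diff:

    from keras.models import Sequential
    from keras.layers import Dense, Dropout

    def build_model(n_fields, nodes, dropout, dropout_input, activation='relu'):
        # Broadcast a single dropout fraction (or None) to one entry per layer.
        if not isinstance(dropout, list):
            dropout = [dropout for _ in nodes]
        if len(dropout) != len(nodes):
            raise ValueError("List of dropout fractions has to be of equal "
                             "size as the number of layers!")
        model = Sequential()
        if dropout_input is None:
            # Without input dropout the first Dense layer defines the input
            # dimension, so the loop below starts at the second hidden layer.
            model.add(Dense(nodes[0], input_dim=n_fields, activation=activation))
            start_layer = 1
        else:
            # With input dropout the Dropout layer defines the input shape
            # and all hidden layers are added in the loop.
            model.add(Dropout(rate=dropout_input, input_shape=(n_fields,)))
            start_layer = 0
        for node_count, dropout_fraction in zip(nodes[start_layer:], dropout[start_layer:]):
            model.add(Dense(node_count, activation=activation))
            # Explicit None check: the diff's "if dropout_fraction > 0:" is
            # simply False for None on Python 2 but raises TypeError on Python 3.
            if dropout_fraction is not None and dropout_fraction > 0:
                model.add(Dropout(rate=dropout_fraction))
        # Last layer is one neuron (binary classification); 'sigmoid' stands
        # in for the class's activation_function_output attribute.
        model.add(Dense(1, activation='sigmoid'))
        return model

    model = build_model(n_fields=10, nodes=[64, 64, 64],
                        dropout=[0.2, 0.2, 0.5], dropout_input=0.1)

One consequence of the start_layer logic worth noting: with dropout_input
unset, the first hidden layer is added before the loop and dropout[0] is
never applied; with input dropout enabled, start_layer is 0 and dropout[0]
does apply after the first hidden layer.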