Skip to content
Snippets Groups Projects
Commit beee5503 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

load and plot before training RNN

parent 7c1241ff
No related branches found
No related tags found
No related merge requests found
...@@ -1361,6 +1361,7 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1361,6 +1361,7 @@ class ClassificationProjectRNN(ClassificationProject):
def __init__(self, name, def __init__(self, name,
recurrent_field_names=None, recurrent_field_names=None,
rnn_layer_nodes=32,
mask_value=-999, mask_value=-999,
**kwargs): **kwargs):
""" """
...@@ -1376,6 +1377,7 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1376,6 +1377,7 @@ class ClassificationProjectRNN(ClassificationProject):
self.recurrent_field_names = recurrent_field_names self.recurrent_field_names = recurrent_field_names
if self.recurrent_field_names is None: if self.recurrent_field_names is None:
self.recurrent_field_names = [] self.recurrent_field_names = []
self.rnn_layer_nodes = rnn_layer_nodes
self.mask_value = mask_value self.mask_value = mask_value
            # convert to list of indices             # convert to list of indices
...@@ -1425,7 +1427,7 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1425,7 +1427,7 @@ class ClassificationProjectRNN(ClassificationProject):
for field_idx in self.recurrent_field_idx: for field_idx in self.recurrent_field_idx:
chan_inp = Input(field_idx.shape[1:]) chan_inp = Input(field_idx.shape[1:])
channel = Masking(mask_value=self.mask_value)(chan_inp) channel = Masking(mask_value=self.mask_value)(chan_inp)
channel = GRU(32)(channel) channel = GRU(self.rnn_layer_nodes)(channel)
# TODO: configure dropout for recurrent layers # TODO: configure dropout for recurrent layers
#channel = Dropout(0.3)(channel) #channel = Dropout(0.3)(channel)
rnn_inputs.append(chan_inp) rnn_inputs.append(chan_inp)
...@@ -1443,11 +1445,15 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1443,11 +1445,15 @@ class ClassificationProjectRNN(ClassificationProject):
combined = Dense(1, activation=self.activation_function_output)(combined) combined = Dense(1, activation=self.activation_function_output)(combined)
self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined) self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
self._compile_or_load_model() self._compile_or_load_model()
else: return self._model
return self._model
def train(self, epochs=10): def train(self, epochs=10):
self.load()
for branch_index, branch in enumerate(self.fields):
self.plot_input(branch_index)
try: try:
self.shuffle_training_data() # needed here too, in order to get correct validation data self.shuffle_training_data() # needed here too, in order to get correct validation data
self.is_training = True self.is_training = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment