From beee5503b517f5ac813d34e6150e3df231c7e860 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 14 Aug 2018 17:43:43 +0200 Subject: [PATCH] load and plot before training RNN --- toolkit.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/toolkit.py b/toolkit.py index 57ef3f3..1fcf30c 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1361,6 +1361,7 @@ class ClassificationProjectRNN(ClassificationProject): def __init__(self, name, recurrent_field_names=None, + rnn_layer_nodes=32, mask_value=-999, **kwargs): """ @@ -1376,6 +1377,7 @@ class ClassificationProjectRNN(ClassificationProject): self.recurrent_field_names = recurrent_field_names if self.recurrent_field_names is None: self.recurrent_field_names = [] + self.rnn_layer_nodes = rnn_layer_nodes self.mask_value = mask_value # convert to of indices @@ -1425,7 +1427,7 @@ class ClassificationProjectRNN(ClassificationProject): for field_idx in self.recurrent_field_idx: chan_inp = Input(field_idx.shape[1:]) channel = Masking(mask_value=self.mask_value)(chan_inp) - channel = GRU(32)(channel) + channel = GRU(self.rnn_layer_nodes)(channel) # TODO: configure dropout for recurrent layers #channel = Dropout(0.3)(channel) rnn_inputs.append(chan_inp) @@ -1443,11 +1445,15 @@ class ClassificationProjectRNN(ClassificationProject): combined = Dense(1, activation=self.activation_function_output)(combined) self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined) self._compile_or_load_model() - else: - return self._model + return self._model def train(self, epochs=10): + self.load() + + for branch_index, branch in enumerate(self.fields): + self.plot_input(branch_index) + try: self.shuffle_training_data() # needed here too, in order to get correct validation data self.is_training = True -- GitLab