From d09ab3d69b852869a93bc7cc374f6730af3e54e7 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 14 Aug 2018 15:40:35 +0200
Subject: [PATCH] get model and training for RNN working

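Factor optimizer setup, model compilation and weight loading out of
ClassificationProject.model into a shared _compile_or_load_model
helper. Implement the ClassificationProjectRNN model with the
functional API (one Masking + GRU channel per recurrent field,
concatenated with the flat inputs), add generator-based training and
provide class-weighted validation data in the multi-input format.
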
---
 toolkit.py | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 98 insertions(+), 36 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 37c77ed..f1e9e4f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -31,9 +31,8 @@ import h5py
 from sklearn.preprocessing import StandardScaler, RobustScaler
 from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
-from keras.models import Sequential
-from keras.layers import Dense, Dropout
-from keras.models import model_from_json
+from keras.models import Sequential, Model, model_from_json
+from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
@@ -649,37 +648,42 @@ class ClassificationProject(object):
             # last layer is one neuron (binary classification)
             self._model.add(Dense(1, activation=self.activation_function_output))
 
-            logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
-            Optimizer = getattr(keras.optimizers, self.optimizer)
-            optimizer = Optimizer(**self.optimizer_opts)
-            logger.info("Compile model")
-            rn_state = np.random.get_state()
-            np.random.seed(self.random_seed)
-            self._model.compile(optimizer=optimizer,
-                                loss=self.loss,
-                                weighted_metrics=['accuracy']
-            )
-            np.random.set_state(rn_state)
+            self._compile_or_load_model()
 
-            if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
-                if self.is_training:
-                    continue_training = self.query_yn("Found previously trained weights - "
-                                                      "continue training (choosing N will restart)? (Y/N) ")
-                else:
-                    continue_training = True
-                if continue_training:
-                    self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
-                    logger.info("Found and loaded previously trained weights")
-                else:
-                    logger.info("Starting completely new model")
-            else:
-                logger.info("No weights found, starting completely new model")
+        return self._model
 
-            # dump to json for documentation
-            with open(os.path.join(self.project_dir, "model.json"), "w") as of:
-                of.write(self._model.to_json())
 
-        return self._model
+    def _compile_or_load_model(self):
+        logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
+        Optimizer = getattr(keras.optimizers, self.optimizer)
+        optimizer = Optimizer(**self.optimizer_opts)
+        logger.info("Compile model")
+        rn_state = np.random.get_state()
+        np.random.seed(self.random_seed)
+        self._model.compile(optimizer=optimizer,
+                            loss=self.loss,
+                            weighted_metrics=['accuracy']
+        )
+        np.random.set_state(rn_state)
+
+        if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
+            if self.is_training:
+                continue_training = self.query_yn("Found previously trained weights - "
+                                                  "continue training (choosing N will restart)? (Y/N) ")
+            else:
+                continue_training = True
+            if continue_training:
+                self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
+                logger.info("Found and loaded previously trained weights")
+            else:
+                logger.info("Starting completely new model")
+        else:
+            logger.info("No weights found, starting completely new model")
+
+        # dump to json for documentation
+        with open(os.path.join(self.project_dir, "model.json"), "w") as of:
+            of.write(self._model.to_json())
 
 
     @property
@@ -1413,18 +1417,67 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     @property
-    def model():
-        pass
+    def model(self):
+        if self._model is None:
+            # following the setup from the tutorial:
+            # https://github.com/YaleATLAS/CERNDeepLearningTutorial
+            rnn_inputs = []
+            rnn_channels = []
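+            # one Masking -> GRU channel per recurrent field, each summarizing a sequence into a fixed-size vector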
+            for field_idx in self.recurrent_field_idx:
+                chan_inp = Input(field_idx.shape[1:])
+                channel = Masking(mask_value=self.mask_value)(chan_inp)
+                channel = GRU(32)(channel)
+                # TODO: configure dropout for recurrent layers
+                #channel = Dropout(0.3)(channel)
+                rnn_inputs.append(chan_inp)
+                rnn_channels.append(channel)
+            flat_input = Input((len(self.flat_fields),))
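+            # optionally apply dropout to the flat input before merging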
+            if self.dropout_input is not None:
+                flat_channel = Dropout(rate=self.dropout_input)(flat_input)
+            else:
+                flat_channel = flat_input
+            combined = concatenate(rnn_channels+[flat_channel])
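+            # dense layers (with optional dropout) on top of the merged features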
+            for node_count, dropout_fraction in zip(self.nodes, self.dropout):
+                combined = Dense(node_count, activation=self.activation_function)(combined)
+                if (dropout_fraction is not None) and (dropout_fraction > 0):
+                    combined = Dropout(rate=dropout_fraction)(combined)
+            combined = Dense(1, activation=self.activation_function_output)(combined)
+            self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
+            self._compile_or_load_model()
+        return self._model
+
+
+    def train(self, epochs=10):
+        self.shuffle_training_data() # needed here too, in order to get correct validation data
+        self.is_training = True
+        try:
+            logger.info("Training on batches for RNN")
+            # note: the batches have class_weight already applied
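+            # steps_per_epoch is chosen such that one epoch roughly corresponds to one pass over the training data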
+            self.model.fit_generator(self.yield_batch(),
+                                     steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
+                                     epochs=epochs,
+                                     validation_data=self.class_weighted_validation_data,
+                                     callbacks=self.callbacks_list)
+        except KeyboardInterrupt:
+            logger.info("Training interrupted - continuing with the rest")
+        self.is_training = False
+        logger.info("Save history")
+        self._dump_history()
 
 
     def get_input_list(self, x):
         "Format the input starting from flat ntuple"
         x_input = []
-        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
-        x_input.append(x_flat)
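+        # the input order has to match the model definition: recurrent channels first, then the flat input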
         for field_idx in self.recurrent_field_idx:
-            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape)
+            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
             x_input.append(x_recurrent)
+        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
+        x_input.append(x_flat)
         return x_input
 
 
@@ -1442,6 +1495,15 @@ class ClassificationProjectRNN(ClassificationProject):
                        w_batch*np.array(self.class_weight)[y_batch.astype(int)])
 
 
+    @property
+    def class_weighted_validation_data(self):
+        "class weighted validation data. Attention: Shuffle training data before using this!"
+        x_val, y_val, w_val = super(ClassificationProjectRNN, self).class_weighted_validation_data
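+        # reformat the flat validation array into the list of inputs the multi-input model expects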
+        x_val_input = self.get_input_list(x_val)
+        return x_val_input, y_val, w_val
+
+
 if __name__ == "__main__":
 
     logging.basicConfig()
-- 
GitLab