From d09ab3d69b852869a93bc7cc374f6730af3e54e7 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 14 Aug 2018 15:40:35 +0200
Subject: [PATCH] model and training for RNN working

---
 toolkit.py | 128 ++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 92 insertions(+), 36 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 37c77ed..f1e9e4f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -31,9 +31,8 @@ import h5py
 from sklearn.preprocessing import StandardScaler, RobustScaler
 from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
-from keras.models import Sequential
-from keras.layers import Dense, Dropout
-from keras.models import model_from_json
+from keras.models import Sequential, Model, model_from_json
+from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
@@ -649,37 +648,41 @@ class ClassificationProject(object):
         # last layer is one neuron (binary classification)
         self._model.add(Dense(1, activation=self.activation_function_output))
 
-        logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
-        Optimizer = getattr(keras.optimizers, self.optimizer)
-        optimizer = Optimizer(**self.optimizer_opts)
-        logger.info("Compile model")
-        rn_state = np.random.get_state()
-        np.random.seed(self.random_seed)
-        self._model.compile(optimizer=optimizer,
-                            loss=self.loss,
-                            weighted_metrics=['accuracy']
-                            )
-        np.random.set_state(rn_state)
+        self._compile_or_load_model()
 
-        if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
-            if self.is_training:
-                continue_training = self.query_yn("Found previously trained weights - "
-                                                  "continue training (choosing N will restart)? (Y/N) ")
-            else:
-                continue_training = True
-            if continue_training:
-                self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
-                logger.info("Found and loaded previously trained weights")
-            else:
-                logger.info("Starting completely new model")
-        else:
-            logger.info("No weights found, starting completely new model")
+        return self._model
 
-        # dump to json for documentation
-        with open(os.path.join(self.project_dir, "model.json"), "w") as of:
-            of.write(self._model.to_json())
-
-        return self._model
+
+    def _compile_or_load_model(self):
+        logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
+        Optimizer = getattr(keras.optimizers, self.optimizer)
+        optimizer = Optimizer(**self.optimizer_opts)
+        logger.info("Compile model")
+        rn_state = np.random.get_state()
+        np.random.seed(self.random_seed)
+        self._model.compile(optimizer=optimizer,
+                            loss=self.loss,
+                            weighted_metrics=['accuracy']
+                            )
+        np.random.set_state(rn_state)
+
+        if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
+            if self.is_training:
+                continue_training = self.query_yn("Found previously trained weights - "
+                                                  "continue training (choosing N will restart)? (Y/N) ")
+            else:
+                continue_training = True
+            if continue_training:
+                self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
+                logger.info("Found and loaded previously trained weights")
+            else:
+                logger.info("Starting completely new model")
+        else:
+            logger.info("No weights found, starting completely new model")
+
+        # dump to json for documentation
+        with open(os.path.join(self.project_dir, "model.json"), "w") as of:
+            of.write(self._model.to_json())
 
 
     @property
@@ -1413,18 +1416,63 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     @property
-    def model():
-        pass
+    def model(self):
+        if self._model is None:
+            # following the setup from the tutorial:
+            # https://github.com/YaleATLAS/CERNDeepLearningTutorial
+            rnn_inputs = []
+            rnn_channels = []
+            for field_idx in self.recurrent_field_idx:
+                chan_inp = Input(field_idx.shape[1:])
+                channel = Masking(mask_value=self.mask_value)(chan_inp)
+                channel = GRU(32)(channel)
+                # TODO: configure dropout for recurrent layers
+                #channel = Dropout(0.3)(channel)
+                rnn_inputs.append(chan_inp)
+                rnn_channels.append(channel)
+            flat_input = Input((len(self.flat_fields),))
+            if self.dropout_input is not None:
+                flat_channel = Dropout(rate=self.dropout_input)(flat_input)
+            else:
+                flat_channel = flat_input
+            combined = concatenate(rnn_channels+[flat_channel])
+            for node_count, dropout_fraction in zip(self.nodes, self.dropout):
+                combined = Dense(node_count, activation=self.activation_function)(combined)
+                if (dropout_fraction is not None) and (dropout_fraction > 0):
+                    combined = Dropout(rate=dropout_fraction)(combined)
+            combined = Dense(1, activation=self.activation_function_output)(combined)
+            self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
+            self._compile_or_load_model()
+        # return the cached or newly built model
+        return self._model
+
+
+    def train(self, epochs=10):
+        try:
+            self.shuffle_training_data()  # needed here too, in order to get correct validation data
+            self.is_training = True
+            logger.info("Training on batches for RNN")
+            # note: the batches have class_weight already applied
+            self.model.fit_generator(self.yield_batch(),
+                                     steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
+                                     epochs=epochs,
+                                     validation_data=self.class_weighted_validation_data,
+                                     callbacks=self.callbacks_list)
+            self.is_training = False
+        except KeyboardInterrupt:
+            logger.info("Interrupt training - continue with rest")
+        logger.info("Save history")
+        self._dump_history()
 
 
     def get_input_list(self, x):
         "Format the input starting from flat ntuple"
         x_input = []
-        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
-        x_input.append(x_flat)
         for field_idx in self.recurrent_field_idx:
-            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape)
+            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
             x_input.append(x_recurrent)
+        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
+        x_input.append(x_flat)
         return x_input
 
 
@@ -1442,6 +1490,14 @@ class ClassificationProjectRNN(ClassificationProject):
                        w_batch*np.array(self.class_weight)[y_batch.astype(int)])
 
 
+    @property
+    def class_weighted_validation_data(self):
+        "class weighted validation data. Attention: Shuffle training data before using this!"
+        x_val, y_val, w_val = super(ClassificationProjectRNN, self).class_weighted_validation_data
+        x_val_input = self.get_input_list(x_val)
+        return x_val_input, y_val, w_val
+
+
 if __name__ == "__main__":
 
     logging.basicConfig()
--
GitLab