From 0a377f2aa4579dd22b7e96e8a7ac6f33e3142986 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Thu, 29 Nov 2018 14:26:41 +0100 Subject: [PATCH] adding BaseLogger and History callbacks --- toolkit.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/toolkit.py b/toolkit.py index c58f005..d16bb2a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -37,7 +37,7 @@ from sklearn.utils.extmath import stable_cumsum from sklearn.model_selection import KFold from keras.models import Sequential, Model, model_from_json from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN -from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList +from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList, BaseLogger from keras.optimizers import SGD import keras.optimizers from keras.utils.vis_utils import model_to_dot @@ -1604,7 +1604,7 @@ class ClassificationProject(object): hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]] return hist_dict - def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None): + def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None, loss_key="loss"): """ Plot the value of the loss function for each epoch @@ -1616,14 +1616,14 @@ class ClassificationProject(object): else: hist_dict = self.history.history - if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict): + if (not loss_key in hist_dict) or (not 'val_'+loss_key in hist_dict): logger.warning("No previous history found for plotting, try global history") hist_dict = self.csv_hist logger.info("Plot losses") - plt.plot(hist_dict['loss']) - plt.plot(hist_dict['val_loss']) - plt.ylabel('loss') + plt.plot(hist_dict[loss_key]) + plt.plot(hist_dict['val_'+loss_key]) + plt.ylabel(loss_key) plt.xlabel('epoch') plt.legend(['training data','validation data'], loc='upper left') if log: @@ -2219,7 +2219,7 @@ class ClassificationProjectDecorr(ClassificationProject): return self._model_adv - def train(self, epochs=10): + def train(self, epochs=10, skip_checkpoint=False): """ Train classifier and adversary concurrently. Most of the garbage in this code block is just organising stuff to get all the keras callbacks @@ -2227,17 +2227,20 @@ class ClassificationProjectDecorr(ClassificationProject): """ batch_generator = self.yield_batch() - metric_list = [] out_labels = self.model.metrics_names + self.model.history = History() callback_metrics = out_labels + ['val_' + n for n in out_labels] - callbacks = CallbackList(self.callbacks_list) + callbacks = CallbackList( + [BaseLogger()] + + self.callbacks_list + + [self.model.history]) callbacks.set_model(self.model) callbacks.set_params({ 'epochs': epochs, 'steps': self.steps_per_epoch, 'verbose': self.verbose, #'do_validation': do_validation, - 'do_validation': False, + 'do_validation': True, 'metrics': callback_metrics, }) self.model.stop_training = False @@ -2264,27 +2267,23 @@ class ClassificationProjectDecorr(ClassificationProject): self.model_adv.train_on_batch( x, y[1:], sample_weight=w[1:] ) - - batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics)) - if metrics is None: - metrics = batch_metrics - else: - metrics = np.concatenate([metrics, batch_metrics]) - avg_metrics = np.mean(metrics, axis=0) outs = list(batch_metrics) for l, o in zip(out_labels, outs): - batch_logs[l] = o + batch_logs[l] = float(o) callbacks.on_batch_end(batch_id, batch_logs) - metric_list.append(avg_metrics) val_metrics = self.model.test_on_batch(*self.validation_data) val_outs = list(val_metrics) for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o + epoch_logs['val_' + l] = float(o) callbacks.on_epoch_end(epoch, epoch_logs) if self.model.stop_training: break callbacks.on_train_end() - return metric_list + + if not skip_checkpoint: + self.checkpoint_model() + + return self.model.history if __name__ == "__main__": -- GitLab