diff --git a/toolkit.py b/toolkit.py
index c58f005dbe29c774a4342edb37d392587b3d4797..d16bb2aae49f11f9113d97b9bd6c22070db3f812 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -37,7 +37,7 @@ from sklearn.utils.extmath import stable_cumsum
 from sklearn.model_selection import KFold
 from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
-from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList
+from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList, BaseLogger
 from keras.optimizers import SGD
 import keras.optimizers
 from keras.utils.vis_utils import model_to_dot
@@ -1604,7 +1604,7 @@ class ClassificationProject(object):
             hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
         return hist_dict
 
-    def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None):
+    def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None, loss_key="loss"):
         """
         Plot the value of the loss function for each epoch
 
@@ -1616,14 +1616,14 @@ class ClassificationProject(object):
         else:
             hist_dict = self.history.history
 
-        if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict):
+        if (loss_key not in hist_dict) or ('val_'+loss_key not in hist_dict):
             logger.warning("No previous history found for plotting, try global history")
             hist_dict = self.csv_hist
 
         logger.info("Plot losses")
-        plt.plot(hist_dict['loss'])
-        plt.plot(hist_dict['val_loss'])
-        plt.ylabel('loss')
+        plt.plot(hist_dict[loss_key])
+        plt.plot(hist_dict['val_'+loss_key])
+        plt.ylabel(loss_key)
         plt.xlabel('epoch')
         plt.legend(['training data','validation data'], loc='upper left')
         if log:
@@ -2219,7 +2219,7 @@ class ClassificationProjectDecorr(ClassificationProject):
         return self._model_adv
 
 
-    def train(self, epochs=10):
+    def train(self, epochs=10, skip_checkpoint=False):
         """
         Train classifier and adversary concurrently. Most of the garbage in this
         code block is just organising stuff to get all the keras callbacks
@@ -2227,17 +2227,20 @@ class ClassificationProjectDecorr(ClassificationProject):
         """
 
         batch_generator = self.yield_batch()
-        metric_list = []
         out_labels = self.model.metrics_names
+        self.model.history = History()
         callback_metrics = out_labels + ['val_' + n for n in out_labels]
-        callbacks = CallbackList(self.callbacks_list)
+        callbacks = CallbackList(
+            [BaseLogger()]
+            + self.callbacks_list
+            + [self.model.history])
         callbacks.set_model(self.model)
         callbacks.set_params({
             'epochs': epochs,
             'steps': self.steps_per_epoch,
             'verbose': self.verbose,
             #'do_validation': do_validation,
-            'do_validation': False,
+            'do_validation': True,
             'metrics': callback_metrics,
         })
         self.model.stop_training = False
@@ -2264,27 +2267,23 @@ class ClassificationProjectDecorr(ClassificationProject):
                 self.model_adv.train_on_batch(
                     x, y[1:], sample_weight=w[1:]
                 )
-
-                batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics))
-                if metrics is None:
-                    metrics = batch_metrics
-                else:
-                    metrics = np.concatenate([metrics, batch_metrics])
-                avg_metrics = np.mean(metrics, axis=0)
                 outs = list(batch_metrics)
                 for l, o in zip(out_labels, outs):
-                    batch_logs[l] = o
+                    batch_logs[l] = float(o)
                 callbacks.on_batch_end(batch_id, batch_logs)
-            metric_list.append(avg_metrics)
             val_metrics = self.model.test_on_batch(*self.validation_data)
             val_outs = list(val_metrics)
             for l, o in zip(out_labels, val_outs):
-                epoch_logs['val_' + l] = o
+                epoch_logs['val_' + l] = float(o)
             callbacks.on_epoch_end(epoch, epoch_logs)
             if self.model.stop_training:
                 break
         callbacks.on_train_end()
-        return metric_list
+
+        if not skip_checkpoint:
+            self.checkpoint_model()
+
+        return self.model.history
 
 
 if __name__ == "__main__":