From 347835717dc714c96bd4466486867438e6ceb229 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Thu, 29 Nov 2018 10:19:10 +0100
Subject: [PATCH] use keras callbacks in adversarial training

---
 toolkit.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

diff --git a/toolkit.py b/toolkit.py
index 9e903ca..0403f9b 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -37,7 +37,7 @@ from sklearn.utils.extmath import stable_cumsum
 from sklearn.model_selection import KFold
 from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
-from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
+from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList
 from keras.optimizers import SGD
 import keras.optimizers
 from keras.utils.vis_utils import model_to_dot
@@ -2220,29 +2220,82 @@ class ClassificationProjectDecorr(ClassificationProject):
 
 
     def train(self, epochs=10):
+        """
+        Train classifier and adversary concurrently. Most of the code in this
+        block is bookkeeping to drive the keras callbacks correctly; it is
+        largely adapted from keras `engine/training_generator.py`.
+        """
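+        # A hypothetical usage sketch (not part of this change): if the project
+        # was configured with, e.g.,
+        #     self.callbacks_list = [EarlyStopping(monitor="val_loss", patience=5)]
+        # then EarlyStopping will set self.model.stop_training = True in its
+        # on_epoch_end, which breaks out of the epoch loop below.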
+
         batch_generator = self.yield_batch()
         metric_list = []
+        out_labels = self.model.metrics_names
+        callback_metrics = out_labels + ['val_' + n for n in out_labels]
+        callbacks = CallbackList(self.callbacks_list)
+        callbacks.set_model(self.model)
+        callbacks.set_params({
+            'epochs': epochs,
+            'steps': self.steps_per_epoch,
+            'verbose': self.verbose,
+            'do_validation': True,  # val metrics are computed below via test_on_batch
+            'metrics': callback_metrics,
+        })
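+        # note: unlike keras `fit_generator`, no BaseLogger/History is prepended
+        # here, so epoch-level train metrics have to be filled in manually below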
+        self.model.stop_training = False
+        callbacks.on_train_begin()
         for epoch in range(epochs):
+            callbacks.on_epoch_begin(epoch)
+            epoch_logs = {}  # reset each epoch
             logger.info("Fitting epoch {}".format(epoch))
             metrics = None
             avg_metrics = None
             for batch_id in tqdm(range(self.steps_per_epoch)):
                 x, y, w = next(batch_generator)
+                batch_logs = {}
+                batch_logs['batch'] = batch_id
+                batch_logs['size'] = len(x)  # assumes x is a single array; a list of inputs would need len(x[0])
+                callbacks.on_batch_begin(batch_id, batch_logs)
+
                 # fit the classifier
                 batch_metrics = self.model.train_on_batch(
                     x, y, sample_weight=w
                 )
+
                 # fit the adversary
                 self.model_adv.train_on_batch(
                     x, y[1:], sample_weight=w[1:]
                 )
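+                # (y and w are lists of targets/weights; y[1:] and w[1:] are
+                # presumably the adversary components, y[0] the classifier's)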
+
                 batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics))
                 if metrics is None:
                     metrics = batch_metrics
                 else:
                     metrics = np.concatenate([metrics, batch_metrics])
                 avg_metrics = np.mean(metrics, axis=0)
+                # batch_metrics was reshaped to (1, n_metrics) above, so take the row
+                outs = list(batch_metrics[0])
+                for l, o in zip(out_labels, outs):
+                    batch_logs[l] = o
+                callbacks.on_batch_end(batch_id, batch_logs)
             metric_list.append(avg_metrics)
+            # averaged train metrics go into the epoch logs by hand (normally BaseLogger's job)
+            for l, o in zip(out_labels, avg_metrics):
+                epoch_logs[l] = o
+            val_metrics = self.model.test_on_batch(*self.validation_data)
+            val_outs = list(val_metrics)
+            for l, o in zip(out_labels, val_outs):
+                epoch_logs['val_' + l] = o
+            callbacks.on_epoch_end(epoch, epoch_logs)
+            if self.model.stop_training:
+                break
+        callbacks.on_train_end()
         return metric_list
 
 
-- 
GitLab