diff --git a/toolkit.py b/toolkit.py
index 9e903ca0b5cc3d3a92b726afb076d9107bd9ee65..0403f9bfa0183311d47082c294449130f49009af 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -37,7 +37,7 @@ from sklearn.utils.extmath import stable_cumsum
 from sklearn.model_selection import KFold
 from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
-from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
+from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList
 from keras.optimizers import SGD
 import keras.optimizers
 from keras.utils.vis_utils import model_to_dot
@@ -2220,29 +2220,70 @@ class ClassificationProjectDecorr(ClassificationProject):
 
 
     def train(self, epochs=10):
+        """
+        Train classifier and adversary concurrently. Most of the code in this
+        block just organises the keras callbacks correctly; it is mostly copied
+        over from keras `engine/training_generator.py`.
+        """
+
         batch_generator = self.yield_batch()
         metric_list = []
+        out_labels = self.model.metrics_names
+        callback_metrics = out_labels + ['val_' + n for n in out_labels]
+        callbacks = CallbackList(self.callbacks_list)
+        callbacks.set_model(self.model)
+        callbacks.set_params({
+            'epochs': epochs,
+            'steps': self.steps_per_epoch,
+            'verbose': self.verbose,
+            #'do_validation': do_validation,
+            'do_validation': False,
+            'metrics': callback_metrics,
+        })
+        self.model.stop_training = False
+        callbacks.on_train_begin()
+        epoch_logs = {}
         for epoch in range(epochs):
+            callbacks.on_epoch_begin(epoch)
             logger.info("Fitting epoch {}".format(epoch))
             metrics = None
             avg_metrics = None
             for batch_id in tqdm(range(self.steps_per_epoch)):
                 x, y, w = next(batch_generator)
+                batch_logs = {}
+                batch_logs['batch'] = batch_id
+                batch_logs['size'] = len(x)
+                callbacks.on_batch_begin(batch_id, batch_logs)
+                # fit the classifier
                 batch_metrics = self.model.train_on_batch(
                     x, y,
                     sample_weight=w
                 )
+                # fit the adversary
                 self.model_adv.train_on_batch(
                     x, y[1:],
                     sample_weight=w[1:]
                 )
+
                 batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics))
                 if metrics is None:
                     metrics = batch_metrics
                 else:
                     metrics = np.concatenate([metrics, batch_metrics])
                 avg_metrics = np.mean(metrics, axis=0)
+                outs = list(batch_metrics[0])  # row 0 of the reshaped (1, n) array
+                for l, o in zip(out_labels, outs):
+                    batch_logs[l] = o
+                callbacks.on_batch_end(batch_id, batch_logs)
             metric_list.append(avg_metrics)
+            val_metrics = self.model.test_on_batch(*self.validation_data)
+            val_outs = list(val_metrics)
+            for l, o in zip(out_labels, val_outs):
+                epoch_logs['val_' + l] = o
+            callbacks.on_epoch_end(epoch, epoch_logs)
+            if self.model.stop_training:
+                break
+        callbacks.on_train_end()
         return metric_list
 
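
For context, here is a minimal, self-contained sketch of the callback wiring the new train() relies on, assuming Keras 2.x (where CallbackList is importable from keras.callbacks). The toy model, data, and hyperparameters below are made up for illustration; only the lifecycle calls (on_train_begin, on_epoch_begin, on_batch_begin/on_batch_end, on_epoch_end, on_train_end) and the stop_training check mirror what the patch does around train_on_batch. The adversary step is omitted: in the patch, model_adv is simply trained on the same batch right after the classifier, with the classifier's first output dropped from the targets and weights (y[1:], w[1:]).

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import CallbackList, EarlyStopping

# Hypothetical stand-ins for self.model / self.callbacks_list from toolkit.py.
model = Sequential([
    Dense(8, activation='relu', input_shape=(4,)),
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])

x_train = np.random.rand(256, 4)
y_train = (np.random.rand(256) > 0.5).astype('float32')
x_val = np.random.rand(64, 4)
y_val = (np.random.rand(64) > 0.5).astype('float32')

epochs, batch_size = 5, 32
steps_per_epoch = len(x_train) // batch_size

out_labels = model.metrics_names
callbacks = CallbackList([EarlyStopping(monitor='val_loss', patience=2)])
callbacks.set_model(model)
callbacks.set_params({
    'epochs': epochs,
    'steps': steps_per_epoch,
    'verbose': 1,
    'do_validation': True,
    'metrics': out_labels + ['val_' + n for n in out_labels],
})

model.stop_training = False
callbacks.on_train_begin()
for epoch in range(epochs):
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    for batch_id in range(steps_per_epoch):
        sl = slice(batch_id * batch_size, (batch_id + 1) * batch_size)
        batch_logs = {'batch': batch_id, 'size': batch_size}
        callbacks.on_batch_begin(batch_id, batch_logs)
        # train_on_batch returns [loss, acc] here because metrics were compiled in
        outs = model.train_on_batch(x_train[sl], y_train[sl])
        for l, o in zip(out_labels, outs):
            batch_logs[l] = o
        callbacks.on_batch_end(batch_id, batch_logs)
    # one validation pass per epoch, feeding val_* entries into the epoch logs
    val_outs = model.test_on_batch(x_val, y_val)
    for l, o in zip(out_labels, val_outs):
        epoch_logs['val_' + l] = o
    callbacks.on_epoch_end(epoch, epoch_logs)
    if model.stop_training:  # set by EarlyStopping in its on_epoch_end
        break
callbacks.on_train_end()

Driving the callbacks manually like this is what lets stock callbacks such as EarlyStopping or CSVLogger work inside a hand-rolled train_on_batch loop: EarlyStopping reads 'val_loss' from the epoch logs and sets model.stop_training, which the loop checks before starting the next epoch.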