Commit 0a377f2a authored by Nikolai.Hartmann

adding BaseLogger and History callbacks

parent 37280c34
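
The mechanics this commit builds on, for reference: in Keras 2.x a BaseLogger averages the per-batch metrics it receives in on_batch_end (weighted by the batch 'size' entry) and writes the averages into the epoch logs, which a History callback then records per epoch. Below is a minimal standalone sketch of that interplay, independent of this repository's classes; the batch sizes and loss values are made up:

    from keras.callbacks import BaseLogger, History, CallbackList

    history = History()
    callbacks = CallbackList([BaseLogger(), history])
    # BaseLogger only averages keys that are declared under 'metrics'
    callbacks.set_params({'epochs': 2, 'steps': 3, 'verbose': 0,
                          'do_validation': False, 'metrics': ['loss']})

    callbacks.on_train_begin()
    for epoch in range(2):
        callbacks.on_epoch_begin(epoch)
        for batch in range(3):
            callbacks.on_batch_begin(batch)
            # pretend a batch of 32 samples just reported this loss
            callbacks.on_batch_end(batch, {'batch': batch, 'size': 32,
                                           'loss': float(batch)})
        epoch_logs = {}
        callbacks.on_epoch_end(epoch, epoch_logs)  # BaseLogger fills in the averaged 'loss'
    callbacks.on_train_end()

    print(history.history)  # {'loss': [1.0, 1.0]}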
@@ -37,7 +37,7 @@ from sklearn.utils.extmath import stable_cumsum
 from sklearn.model_selection import KFold
 from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
-from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList
+from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard, CallbackList, BaseLogger
 from keras.optimizers import SGD
 import keras.optimizers
 from keras.utils.vis_utils import model_to_dot
@@ -1604,7 +1604,7 @@ class ClassificationProject(object):
             hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
         return hist_dict
 
 
-    def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None):
+    def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None, loss_key="loss"):
         """
         Plot the value of the loss function for each epoch
@@ -1616,14 +1616,14 @@ class ClassificationProject(object):
         else:
             hist_dict = self.history.history
 
-        if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict):
+        if (not loss_key in hist_dict) or (not 'val_'+loss_key in hist_dict):
             logger.warning("No previous history found for plotting, try global history")
             hist_dict = self.csv_hist
 
         logger.info("Plot losses")
-        plt.plot(hist_dict['loss'])
-        plt.plot(hist_dict['val_loss'])
-        plt.ylabel('loss')
+        plt.plot(hist_dict[loss_key])
+        plt.plot(hist_dict['val_'+loss_key])
+        plt.ylabel(loss_key)
         plt.xlabel('epoch')
         plt.legend(['training data','validation data'], loc='upper left')
         if log:
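
The new loss_key argument generalises the plot from the total 'loss'/'val_loss' pair to any metric recorded under a given key and its 'val_'-prefixed partner. A hypothetical call site (the project instance and metric key below are assumptions, not taken from this commit):

    project = ClassificationProject("my_project")   # hypothetical instance
    project.plot_loss()                             # default: 'loss' vs. 'val_loss'
    project.plot_loss(loss_key="dense_1_loss")      # any '<key>' / 'val_<key>' pair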
@@ -2219,7 +2219,7 @@ class ClassificationProjectDecorr(ClassificationProject):
         return self._model_adv
 
 
-    def train(self, epochs=10):
+    def train(self, epochs=10, skip_checkpoint=False):
         """
         Train classifier and adversary concurrently. Most of the garbage in this
         code block is just organising stuff to get all the keras callbacks
@@ -2227,17 +2227,20 @@ class ClassificationProjectDecorr(ClassificationProject):
         """
         batch_generator = self.yield_batch()
         metric_list = []
         out_labels = self.model.metrics_names
         self.model.history = History()
         callback_metrics = out_labels + ['val_' + n for n in out_labels]
-        callbacks = CallbackList(self.callbacks_list)
+        callbacks = CallbackList(
+            [BaseLogger()]
+            + self.callbacks_list
+            + [self.model.history])
         callbacks.set_model(self.model)
         callbacks.set_params({
             'epochs': epochs,
             'steps': self.steps_per_epoch,
             'verbose': self.verbose,
             #'do_validation': do_validation,
-            'do_validation': False,
+            'do_validation': True,
             'metrics': callback_metrics,
         })
         self.model.stop_training = False
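
One subtlety behind the set_params call above: BaseLogger only propagates metric keys that are listed under 'metrics' into the epoch logs, so the 'val_'-prefixed names have to be declared up front even though they are only filled in at on_epoch_end, and 'do_validation': True advertises to the callbacks that validation entries will be present in those logs. A small illustration of how callback_metrics expands (the metric names are invented):

    out_labels = ['loss', 'adv_loss']   # hypothetical model.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]
    # -> ['loss', 'adv_loss', 'val_loss', 'val_adv_loss']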
@@ -2264,27 +2267,23 @@ class ClassificationProjectDecorr(ClassificationProject):
                 self.model_adv.train_on_batch(
                     x, y[1:], sample_weight=w[1:]
                 )
-                batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics))
-                if metrics is None:
-                    metrics = batch_metrics
-                else:
-                    metrics = np.concatenate([metrics, batch_metrics])
-                avg_metrics = np.mean(metrics, axis=0)
 
                 outs = list(batch_metrics)
                 for l, o in zip(out_labels, outs):
-                    batch_logs[l] = o
+                    batch_logs[l] = float(o)
                 callbacks.on_batch_end(batch_id, batch_logs)
 
-            metric_list.append(avg_metrics)
             val_metrics = self.model.test_on_batch(*self.validation_data)
             val_outs = list(val_metrics)
             for l, o in zip(out_labels, val_outs):
-                epoch_logs['val_' + l] = o
+                epoch_logs['val_' + l] = float(o)
 
             callbacks.on_epoch_end(epoch, epoch_logs)
             if self.model.stop_training:
                 break
 
         callbacks.on_train_end()
-        return metric_list
+        if not skip_checkpoint:
+            self.checkpoint_model()
+        return self.model.history
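
With these changes train() hands back the History object instead of the hand-rolled metric_list, so per-epoch curves can be read off directly. A hypothetical call site (the project name and epoch count are assumptions):

    proj = ClassificationProjectDecorr("decorr_test")     # hypothetical instance
    hist = proj.train(epochs=20, skip_checkpoint=True)    # suppress the final checkpoint
    print(hist.history["loss"])       # averaged training loss per epoch
    print(hist.history["val_loss"])   # validation loss per epoch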