Skip to content
Snippets Groups Projects
Commit 6c1b3d00 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

keyboard interrupt for adversarial training

parent b049f3d7
No related branches found
No related tags found
No related merge requests found
...@@ -2280,38 +2280,41 @@ class ClassificationProjectDecorr(ClassificationProject): ...@@ -2280,38 +2280,41 @@ class ClassificationProjectDecorr(ClassificationProject):
self.model.stop_training = False self.model.stop_training = False
callbacks.on_train_begin() callbacks.on_train_begin()
epoch_logs = {} epoch_logs = {}
for epoch in range(epochs): try:
callbacks.on_epoch_begin(epoch) for epoch in range(epochs):
logger.info("Fitting epoch {}".format(epoch)) callbacks.on_epoch_begin(epoch)
metrics = None logger.info("Fitting epoch {}".format(epoch))
avg_metrics = None metrics = None
for batch_id in tqdm(range(self.steps_per_epoch)): avg_metrics = None
x, y, w = next(batch_generator) for batch_id in tqdm(range(self.steps_per_epoch)):
batch_logs = {} x, y, w = next(batch_generator)
batch_logs['batch'] = batch_id batch_logs = {}
batch_logs['size'] = len(x) batch_logs['batch'] = batch_id
callbacks.on_batch_begin(batch_id, batch_logs) batch_logs['size'] = len(x)
callbacks.on_batch_begin(batch_id, batch_logs)
# fit the classifier
batch_metrics = self.model.train_on_batch( # fit the classifier
x, y, sample_weight=w batch_metrics = self.model.train_on_batch(
) x, y, sample_weight=w
)
# fit the adversary # fit the adversary
self.model_adv.train_on_batch( self.model_adv.train_on_batch(
x, y[1:], sample_weight=w[1:] x, y[1:], sample_weight=w[1:]
) )
outs = list(batch_metrics) outs = list(batch_metrics)
for l, o in zip(out_labels, outs): for l, o in zip(out_labels, outs):
batch_logs[l] = float(o) batch_logs[l] = float(o)
callbacks.on_batch_end(batch_id, batch_logs) callbacks.on_batch_end(batch_id, batch_logs)
val_metrics = self.model.test_on_batch(*self.validation_data) val_metrics = self.model.test_on_batch(*self.validation_data)
val_outs = list(val_metrics) val_outs = list(val_metrics)
for l, o in zip(out_labels, val_outs): for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = float(o) epoch_logs['val_' + l] = float(o)
callbacks.on_epoch_end(epoch, epoch_logs) callbacks.on_epoch_end(epoch, epoch_logs)
if self.model.stop_training: if self.model.stop_training:
break break
except KeyboardInterrupt:
pass
callbacks.on_train_end() callbacks.on_train_end()
if not skip_checkpoint: if not skip_checkpoint:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment