Skip to content
Snippets Groups Projects
Commit 2e9b6f67 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

support for tensorboard

parent 489b934d
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.models import model_from_json
from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint
from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
from keras.optimizers import SGD
import keras.optimizers
import matplotlib.pyplot as plt
......@@ -145,6 +145,10 @@ class ClassificationProject(object):
you change the format of the saved model weights it has to be of
the form "weights*.h5"
:param use_tensorboard: if True, use the tensorboard callback to write logs for tensorboard
:param tensorboard_opts: options for the TensorBoard callback
:param balance_dataset: if True, balance the dataset instead of
applying class weights. Only a fraction of the overrepresented
class will be used in each epoch, but different subsets of the
......@@ -212,6 +216,8 @@ class ClassificationProject(object):
earlystopping_opts=None,
use_modelcheckpoint=True,
modelcheckpoint_opts=None,
use_tensorboard=False,
tensorboard_opts=None,
random_seed=1234,
balance_dataset=False,
loss='binary_crossentropy'):
......@@ -261,6 +267,7 @@ class ClassificationProject(object):
self.optimizer = optimizer
self.use_earlystopping = use_earlystopping
self.use_modelcheckpoint = use_modelcheckpoint
self.use_tensorboard = use_tensorboard
if optimizer_opts is None:
optimizer_opts = dict()
self.optimizer_opts = optimizer_opts
......@@ -274,6 +281,11 @@ class ClassificationProject(object):
filepath="weights.h5"
)
self.modelcheckpoint_opts = modelcheckpoint_opts
self.tensorboard_opts = dict(
log_dir=os.path.join(self.project_dir, "tensorboard"),
)
if tensorboard_opts is not None:
self.tensorboard_opts.update(**tensorboard_opts)
self.random_seed = random_seed
self.balance_dataset = balance_dataset
self.loss = loss
......@@ -482,6 +494,8 @@ class ClassificationProject(object):
if not os.path.dirname(mc.filepath) == self.project_dir:
mc.filepath = os.path.join(self.project_dir, mc.filepath)
logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath))
if self.use_tensorboard:
self._callbacks_list.append(TensorBoard(**self.tensorboard_opts))
self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
return self._callbacks_list
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment