From 2e9b6f672ec58f5041031b0beacc035bf9b222d4 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Fri, 10 Aug 2018 11:30:23 +0200 Subject: [PATCH] support for tensorboard --- toolkit.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/toolkit.py b/toolkit.py index 646f09f..86765ad 100755 --- a/toolkit.py +++ b/toolkit.py @@ -32,7 +32,7 @@ from sklearn.metrics import roc_curve, auc from keras.models import Sequential from keras.layers import Dense, Dropout from keras.models import model_from_json -from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint +from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard from keras.optimizers import SGD import keras.optimizers import matplotlib.pyplot as plt @@ -145,6 +145,10 @@ class ClassificationProject(object): you change the format of the saved model weights it has to be of the form "weights*.h5" + :param use_tensorboard: if True, use the tensorboard callback to write logs for tensorboard + + :param tensorboard_opts: options for the TensorBoard callback + :param balance_dataset: if True, balance the dataset instead of applying class weights. Only a fraction of the overrepresented class will be used in each epoch, but different subsets of the @@ -212,6 +216,8 @@ class ClassificationProject(object): earlystopping_opts=None, use_modelcheckpoint=True, modelcheckpoint_opts=None, + use_tensorboard=False, + tensorboard_opts=None, random_seed=1234, balance_dataset=False, loss='binary_crossentropy'): @@ -261,6 +267,7 @@ class ClassificationProject(object): self.optimizer = optimizer self.use_earlystopping = use_earlystopping self.use_modelcheckpoint = use_modelcheckpoint + self.use_tensorboard = use_tensorboard if optimizer_opts is None: optimizer_opts = dict() self.optimizer_opts = optimizer_opts @@ -274,6 +281,11 @@ class ClassificationProject(object): filepath="weights.h5" ) self.modelcheckpoint_opts = modelcheckpoint_opts + self.tensorboard_opts = dict( + log_dir=os.path.join(self.project_dir, "tensorboard"), + ) + if tensorboard_opts is not None: + self.tensorboard_opts.update(**tensorboard_opts) self.random_seed = random_seed self.balance_dataset = balance_dataset self.loss = loss @@ -482,6 +494,8 @@ class ClassificationProject(object): if not os.path.dirname(mc.filepath) == self.project_dir: mc.filepath = os.path.join(self.project_dir, mc.filepath) logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath)) + if self.use_tensorboard: + self._callbacks_list.append(TensorBoard(**self.tensorboard_opts)) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) return self._callbacks_list -- GitLab