From 2e9b6f672ec58f5041031b0beacc035bf9b222d4 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Fri, 10 Aug 2018 11:30:23 +0200
Subject: [PATCH] support for tensorboard

---
 toolkit.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/toolkit.py b/toolkit.py
index 646f09f..86765ad 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -32,7 +32,7 @@ from sklearn.metrics import roc_curve, auc
 from keras.models import Sequential
 from keras.layers import Dense, Dropout
 from keras.models import model_from_json
-from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint
+from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
 from keras.optimizers import SGD
 import keras.optimizers
 import matplotlib.pyplot as plt
@@ -145,6 +145,10 @@ class ClassificationProject(object):
                                  you change the format of the saved model weights it has to be of
                                  the form "weights*.h5"
 
+    :param use_tensorboard: if True, use the tensorboard callback to write logs for tensorboard
+
+    :param tensorboard_opts: options for the TensorBoard callback
+
     :param balance_dataset: if True, balance the dataset instead of
                             applying class weights. Only a fraction of the overrepresented
                             class will be used in each epoch, but different subsets of the
@@ -212,6 +216,8 @@ class ClassificationProject(object):
                         earlystopping_opts=None,
                         use_modelcheckpoint=True,
                         modelcheckpoint_opts=None,
+                        use_tensorboard=False,
+                        tensorboard_opts=None,
                         random_seed=1234,
                         balance_dataset=False,
                         loss='binary_crossentropy'):
@@ -261,6 +267,7 @@ class ClassificationProject(object):
         self.optimizer = optimizer
         self.use_earlystopping = use_earlystopping
         self.use_modelcheckpoint = use_modelcheckpoint
+        self.use_tensorboard = use_tensorboard
         if optimizer_opts is None:
             optimizer_opts = dict()
         self.optimizer_opts = optimizer_opts
@@ -274,6 +281,11 @@ class ClassificationProject(object):
                 filepath="weights.h5"
             )
         self.modelcheckpoint_opts = modelcheckpoint_opts
+        self.tensorboard_opts = dict(
+            log_dir=os.path.join(self.project_dir, "tensorboard"),
+        )
+        if tensorboard_opts is not None:
+            self.tensorboard_opts.update(**tensorboard_opts)
         self.random_seed = random_seed
         self.balance_dataset = balance_dataset
         self.loss = loss
@@ -482,6 +494,8 @@ class ClassificationProject(object):
             if not os.path.dirname(mc.filepath) == self.project_dir:
                 mc.filepath = os.path.join(self.project_dir, mc.filepath)
                 logger.debug("Prepending project dir to ModelCheckpoint filepath: {}".format(mc.filepath))
+        if self.use_tensorboard:
+            self._callbacks_list.append(TensorBoard(**self.tensorboard_opts))
         self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
         return self._callbacks_list
 
-- 
GitLab