From d9f78c7040de8fe2f5b667f71755c3ee1d297cba Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Wed, 25 Jul 2018 16:39:54 +0200
Subject: [PATCH] loss function configurable

---
 toolkit.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 5a312a2..e801ef6 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -150,6 +150,8 @@ class ClassificationProject(object):
                         random data is also used for shuffling the training data, so results may vary still. To
                         produce consistent results, set the numpy random seed before training.
 
+    :param loss: loss function name
+
     """
 
 
@@ -205,7 +207,8 @@ class ClassificationProject(object):
                         use_modelcheckpoint=True,
                         modelcheckpoint_opts=None,
                         random_seed=1234,
-                        balance_dataset=False):
+                        balance_dataset=False,
+                        loss='binary_crossentropy'):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -253,6 +256,7 @@ class ClassificationProject(object):
         self.modelcheckpoint_opts = modelcheckpoint_opts
         self.random_seed = random_seed
         self.balance_dataset = balance_dataset
+        self.loss = loss
 
         self.s_train = None
         self.b_train = None
@@ -562,7 +566,7 @@ class ClassificationProject(object):
             rn_state = np.random.get_state()
             np.random.seed(self.random_seed)
             self._model.compile(optimizer=optimizer,
-                                loss='binary_crossentropy',
+                                loss=self.loss,
                                 metrics=['accuracy'])
             np.random.set_state(rn_state)
             if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
@@ -1031,7 +1035,7 @@ class ClassificationProject(object):
         plt.plot(hist_dict['val_loss'])
         plt.ylabel('loss')
         plt.xlabel('epoch')
-        plt.legend(['train','test'], loc='upper left')
+        plt.legend(['training data','validation data'], loc='upper left')
         if log:
             plt.yscale("log")
         if xlim is not None:
@@ -1065,7 +1069,7 @@ class ClassificationProject(object):
         plt.title('model accuracy')
         plt.ylabel('accuracy')
         plt.xlabel('epoch')
-        plt.legend(['train', 'test'], loc='upper left')
+        plt.legend(['training data', 'validation data'], loc='upper left')
         if log:
             plt.yscale("log")
         plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
-- 
GitLab