diff --git a/toolkit.py b/toolkit.py
index 5a312a204a434545032310cf87ea7d9fda706914..e801ef6fc3f9d41b6ad42f9dc2eb74eb2d79a4ae 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -150,6 +150,8 @@ class ClassificationProject(object):
                         random data is also used for shuffling the training data, so results may vary still. To
                         produce consistent results, set the numpy random seed before training.
 
+    :param loss: loss function name
+
     """
 
 
@@ -205,7 +207,8 @@ class ClassificationProject(object):
                         use_modelcheckpoint=True,
                         modelcheckpoint_opts=None,
                         random_seed=1234,
-                        balance_dataset=False):
+                        balance_dataset=False,
+                        loss='binary_crossentropy'):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -253,6 +256,7 @@ class ClassificationProject(object):
         self.modelcheckpoint_opts = modelcheckpoint_opts
         self.random_seed = random_seed
         self.balance_dataset = balance_dataset
+        self.loss = loss
 
         self.s_train = None
         self.b_train = None
@@ -562,7 +566,7 @@ class ClassificationProject(object):
             rn_state = np.random.get_state()
             np.random.seed(self.random_seed)
             self._model.compile(optimizer=optimizer,
-                                loss='binary_crossentropy',
+                                loss=self.loss,
                                 metrics=['accuracy'])
             np.random.set_state(rn_state)
             if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
@@ -1031,7 +1035,7 @@ class ClassificationProject(object):
         plt.plot(hist_dict['val_loss'])
         plt.ylabel('loss')
         plt.xlabel('epoch')
-        plt.legend(['train','test'], loc='upper left')
+        plt.legend(['training data','validation data'], loc='upper left')
         if log:
             plt.yscale("log")
         if xlim is not None:
@@ -1065,7 +1069,7 @@ class ClassificationProject(object):
         plt.title('model accuracy')
         plt.ylabel('accuracy')
         plt.xlabel('epoch')
-        plt.legend(['train', 'test'], loc='upper left')
+        plt.legend(['training data', 'validation data'], loc='upper left')
         if log:
             plt.yscale("log")
         plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))