diff --git a/toolkit.py b/toolkit.py
index 3bda85e220409c946ea7da305f475fb22ce04517..7199468c71cbbe1198e6ba0f9e49e00ae7cbe464 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -184,6 +184,8 @@ class ClassificationProject(object):
 
     :param loss: loss function name (or list of names in case of regression targets)
 
+    :param loss_weights: (optional) list of weights to weight the individual losses (for multiple targets)
+
     :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
 
     :param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background)
@@ -260,6 +262,7 @@ class ClassificationProject(object):
                         shuffle_seed=42,
                         balance_dataset=False,
                         loss='binary_crossentropy',
+                        loss_weights=None,
                         mask_value=None,
                         apply_class_weight=True,
                         normalize_weights=True,
@@ -346,6 +349,7 @@ class ClassificationProject(object):
         self.shuffle_seed = shuffle_seed
         self.balance_dataset = balance_dataset
         self.loss = loss
+        self.loss_weights = loss_weights
         if self.regression_branches and (not isinstance(self.loss, list)):
             self.loss = [self.loss]+["mean_squared_error"]*len(self.regression_branches)
 
@@ -759,6 +763,7 @@ class ClassificationProject(object):
         np.random.seed(self.random_seed)
         self._model.compile(optimizer=optimizer,
                             loss=self.loss,
+                            loss_weights=self.loss_weights,
                             weighted_metrics=['accuracy']
         )
         np.random.set_state(rn_state)