diff --git a/toolkit.py b/toolkit.py index 3bda85e220409c946ea7da305f475fb22ce04517..7199468c71cbbe1198e6ba0f9e49e00ae7cbe464 100755 --- a/toolkit.py +++ b/toolkit.py @@ -184,6 +184,8 @@ class ClassificationProject(object): :param loss: loss function name (or list of names in case of regression targets) + :param loss_weights: (optional) list of weights to weight the individual losses (for multiple targets) + :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets) :param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background) @@ -260,6 +262,7 @@ class ClassificationProject(object): shuffle_seed=42, balance_dataset=False, loss='binary_crossentropy', + loss_weights=None, mask_value=None, apply_class_weight=True, normalize_weights=True, @@ -346,6 +349,7 @@ class ClassificationProject(object): self.shuffle_seed = shuffle_seed self.balance_dataset = balance_dataset self.loss = loss + self.loss_weights = loss_weights if self.regression_branches and (not isinstance(self.loss, list)): self.loss = [self.loss]+["mean_squared_error"]*len(self.regression_branches) @@ -759,6 +763,7 @@ class ClassificationProject(object): np.random.seed(self.random_seed) self._model.compile(optimizer=optimizer, loss=self.loss, + loss_weights=self.loss_weights, weighted_metrics=['accuracy'] ) np.random.set_state(rn_state)