From 3c3d6c505b713d69ebe62924db3899d2c706e1ad Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 19 Nov 2018 17:59:02 +0100 Subject: [PATCH] weight individual losses --- toolkit.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/toolkit.py b/toolkit.py index 3bda85e..7199468 100755 --- a/toolkit.py +++ b/toolkit.py @@ -184,6 +184,8 @@ class ClassificationProject(object): :param loss: loss function name (or list of names in case of regression targets) + :param loss_weights: (optional) list of weights to weight the individual losses (for multiple targets) + :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets) :param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background) @@ -260,6 +262,7 @@ class ClassificationProject(object): shuffle_seed=42, balance_dataset=False, loss='binary_crossentropy', + loss_weights=None, mask_value=None, apply_class_weight=True, normalize_weights=True, @@ -346,6 +349,7 @@ class ClassificationProject(object): self.shuffle_seed = shuffle_seed self.balance_dataset = balance_dataset self.loss = loss + self.loss_weights = loss_weights if self.regression_branches and (not isinstance(self.loss, list)): self.loss = [self.loss]+["mean_squared_error"]*len(self.regression_branches) @@ -759,6 +763,7 @@ class ClassificationProject(object): np.random.seed(self.random_seed) self._model.compile(optimizer=optimizer, loss=self.loss, + loss_weights=self.loss_weights, weighted_metrics=['accuracy'] ) np.random.set_state(rn_state) -- GitLab