Skip to content
Snippets Groups Projects
Commit 3c3d6c50 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

weight individual losses

parent e67e13f2
No related branches found
No related tags found
No related merge requests found
......@@ -184,6 +184,8 @@ class ClassificationProject(object):
:param loss: loss function name (or list of names in case of regression targets)
:param loss_weights: (optional) list of weights to weight the individual losses (for multiple targets)
:param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
:param apply_class_weight: apply a weight that scales the events such that sumw(signal) = sumw(background)
......@@ -260,6 +262,7 @@ class ClassificationProject(object):
shuffle_seed=42,
balance_dataset=False,
loss='binary_crossentropy',
loss_weights=None,
mask_value=None,
apply_class_weight=True,
normalize_weights=True,
......@@ -346,6 +349,7 @@ class ClassificationProject(object):
self.shuffle_seed = shuffle_seed
self.balance_dataset = balance_dataset
self.loss = loss
self.loss_weights = loss_weights
if self.regression_branches and (not isinstance(self.loss, list)):
self.loss = [self.loss]+["mean_squared_error"]*len(self.regression_branches)
......@@ -759,6 +763,7 @@ class ClassificationProject(object):
np.random.seed(self.random_seed)
self._model.compile(optimizer=optimizer,
loss=self.loss,
loss_weights=self.loss_weights,
weighted_metrics=['accuracy']
)
np.random.set_state(rn_state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment