Commit 3bc3c4f1 authored by Nikolai.Hartmann

Adding appropriate class weights for balance_dataset mode

parent 3235c3ea
@@ -230,6 +230,7 @@ class ClassificationProject(object):
         self._scaler = None
         self._class_weight = None
+        self._balanced_class_weight = None
         self._model = None
         self._history = None
         self._callbacks_list = []
@@ -534,10 +535,29 @@ class ClassificationProject(object):
             sumw_bkg = np.sum(self.w_train[self.y_train == 0])
             sumw_sig = np.sum(self.w_train[self.y_train == 1])
             self._class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
             logger.debug("Calculated class_weight: {}".format(self._class_weight))
         return self._class_weight

+    @property
+    def balanced_class_weight(self):
+        """
+        Class weight for the balance_dataset method
+
+        Since we have equal number of signal and background events in
+        each batch, we need to balance the ratio of sum of weights per
+        event with class weights
+        """
+        if self._balanced_class_weight is None:
+            sumw_bkg = np.sum(self.w_train[self.y_train == 0])
+            sumw_sig = np.sum(self.w_train[self.y_train == 1])
+            # use sumw *per event* in this case
+            sumw_bkg /= len(self.w_train[self.y_train == 0])
+            sumw_sig /= len(self.w_train[self.y_train == 1])
+            self._balanced_class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
+            logger.debug("Calculated balanced_class_weight: {}".format(self._balanced_class_weight))
+        return self._balanced_class_weight
+
     def load(self, reload=False):
         "Load all data needed for plotting and training"
@@ -601,13 +621,19 @@ class ClassificationProject(object):
             for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
                 yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
                        y_train[y_train==class_label][start:start+int(self.batch_size/2)],
-                       w_train[y_train==class_label][start:start+int(self.batch_size/2)])
+                       w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label])
             # restart

     def yield_balanced_batch(self):
         "generate batches with equal amounts of both classes"
+        logcounter = 0
         for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)):
+            if logcounter == 10:
+                logger.debug("\rSumw sig*balanced_class_weight[1]: {}".format(np.sum(batch_1[2])))
+                logger.debug("\rSumw bkg*balanced_class_weight[0]: {}".format(np.sum(batch_0[2])))
+                logcounter = 0
+            logcounter += 1
             yield (np.concatenate((batch_0[0], batch_1[0])),
                    np.concatenate((batch_0[1], batch_1[1])),
                    np.concatenate((batch_0[2], batch_1[2])))
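For reference, here is a standalone sketch of the balanced-batch pattern introduced above, with hypothetical function signatures that take the arrays and class weights explicitly instead of reading them from self (the real yield_batch also restarts when it runs out of events, which is omitted here):

import numpy as np
try:
    from itertools import izip  # Python 2, as used in the diff
except ImportError:
    izip = zip                  # Python 3 fallback

batch_size = 4

def yield_batch(x, y, w, class_label, class_weight):
    # half-batches of a single class, with the class weight folded into
    # the per-event weights (mirrors the modified yield_batch above)
    mask = (y == class_label)
    half = int(batch_size / 2)
    for start in range(0, len(x[mask]), half):
        sl = slice(start, start + half)
        yield x[mask][sl], y[mask][sl], w[mask][sl] * class_weight[class_label]

def yield_balanced_batch(x, y, w, class_weight):
    # stitch one background and one signal half-batch together per step
    for b0, b1 in izip(yield_batch(x, y, w, 0, class_weight),
                       yield_batch(x, y, w, 1, class_weight)):
        yield (np.concatenate((b0[0], b1[0])),
               np.concatenate((b0[1], b1[1])),
               np.concatenate((b0[2], b1[2])))

# tiny demonstration with made-up data and unit weights
x = np.arange(16).reshape(8, 2).astype(float)
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
w = np.ones(8)
for xb, yb, wb in yield_balanced_batch(x, y, w, [1.0, 1.0]):
    print(yb, wb)  # each batch holds two events of each class

Each yielded batch holds batch_size/2 events of each class, with the class weights already folded into the per-event weights, which is why the training hunk below notes that no extra class_weight is needed.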
@@ -645,6 +671,7 @@ class ClassificationProject(object):
         self.is_training = True
         labels, label_counts = np.unique(self.y_train, return_counts=True)
         logger.info("Training on balanced batches")
+        # note: the batches have balanced_class_weight already applied
         self.model.fit_generator(self.yield_balanced_batch(),
                                  steps_per_epoch=int(min(label_counts)/self.batch_size),
                                  epochs=epochs,
...