Skip to content
Snippets Groups Projects
Commit 5b541afd authored by Nikolai's avatar Nikolai
Browse files

making balanced training more efficient

parent 71b89893
No related branches found
No related tags found
No related merge requests found
...@@ -760,26 +760,28 @@ class ClassificationProject(object): ...@@ -760,26 +760,28 @@ class ClassificationProject(object):
return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index] return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
def yield_batch(self, class_label): def yield_single_class_batch(self, class_label):
"""
Generate batches of half batch size, containing only entries for the given class label.
The weights are multiplied by balanced_class_weight.
"""
x_train, y_train, w_train = self.training_data
class_idx = np.where(y_train==class_label)[0]
while True: while True:
x_train, y_train, w_train = self.training_data # shuffle the indices for this class label
# shuffle the entries for this class label shuffled_idx = np.random.permutation(class_idx)
rn_state = np.random.get_state()
x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
np.random.set_state(rn_state)
w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
# yield them batch wise # yield them batch wise
for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)): for start in range(0, len(shuffled_idx), int(self.batch_size/2)):
yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)], yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]],
y_train[y_train==class_label][start:start+int(self.batch_size/2)], y_train[shuffled_idx[start:start+int(self.batch_size/2)]],
w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label]) w_train[shuffled_idx[start:start+int(self.batch_size/2)]]*self.balanced_class_weight[class_label])
# restart
def yield_balanced_batch(self): def yield_balanced_batch(self):
"generate batches with equal amounts of both classes" "generate batches with equal amounts of both classes"
logcounter = 0 logcounter = 0
for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)): for batch_0, batch_1 in izip(self.yield_single_class_batch(0),
self.yield_single_class_batch(1)):
if logcounter == 10: if logcounter == 10:
logger.debug("\rSumw sig*balanced_class_weight[1]: {}".format(np.sum(batch_1[2]))) logger.debug("\rSumw sig*balanced_class_weight[1]: {}".format(np.sum(batch_1[2])))
logger.debug("\rSumw bkg*balanced_class_weight[0]: {}".format(np.sum(batch_0[2]))) logger.debug("\rSumw bkg*balanced_class_weight[0]: {}".format(np.sum(batch_0[2])))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment