diff --git a/toolkit.py b/toolkit.py index 3bd7d127a23d0d5039c75fce86cf78816a894067..6f2d075acc8110b723bf37565ce39538e0bcdede 100755 --- a/toolkit.py +++ b/toolkit.py @@ -760,26 +760,28 @@ class ClassificationProject(object): return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index] - def yield_batch(self, class_label): + def yield_single_class_batch(self, class_label): + """ + Generate batches of half batch size, containing only entries for the given class label. + The weights are multiplied by balanced_class_weight. + """ + x_train, y_train, w_train = self.training_data + class_idx = np.where(y_train==class_label)[0] while True: - x_train, y_train, w_train = self.training_data - # shuffle the entries for this class label - rn_state = np.random.get_state() - x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label]) - np.random.set_state(rn_state) - w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label]) + # shuffle the indices for this class label + shuffled_idx = np.random.permutation(class_idx) # yield them batch wise - for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)): - yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)], - y_train[y_train==class_label][start:start+int(self.batch_size/2)], - w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label]) - # restart + for start in range(0, len(shuffled_idx), int(self.batch_size/2)): + yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]], + y_train[shuffled_idx[start:start+int(self.batch_size/2)]], + w_train[shuffled_idx[start:start+int(self.batch_size/2)]]*self.balanced_class_weight[class_label]) def yield_balanced_batch(self): "generate batches with equal amounts of both classes" logcounter = 0 - for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)): + for batch_0, batch_1 in izip(self.yield_single_class_batch(0), + self.yield_single_class_batch(1)): if logcounter == 10: logger.debug("\rSumw sig*balanced_class_weight[1]: {}".format(np.sum(batch_1[2]))) logger.debug("\rSumw bkg*balanced_class_weight[0]: {}".format(np.sum(batch_0[2])))