diff --git a/toolkit.py b/toolkit.py index 3343700ef7cec09a905d6f4ef66bbec85171dec4..8451292dde62c7e35680c6df575ba17fba32d39c 100755 --- a/toolkit.py +++ b/toolkit.py @@ -598,10 +598,10 @@ class ClassificationProject(object): np.random.set_state(rn_state) w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label]) # yield them batch wise - for start in range(0, len(x_train[y_train==class_label]), self.batch_size): - yield (x_train[y_train==class_label][start:start+self.batch_size], - y_train[y_train==class_label][start:start+self.batch_size], - w_train[y_train==class_label][start:start+self.batch_size]) + for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)): + yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)], + y_train[y_train==class_label][start:start+int(self.batch_size/2)], + w_train[y_train==class_label][start:start+int(self.batch_size/2)]) # restart