From da18e449f25b2e6e0542c73190c8405a89222c36 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Thu, 17 May 2018 16:36:36 +0200 Subject: [PATCH] correct batch sizes for balanced training --- toolkit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/toolkit.py b/toolkit.py index 3343700..8451292 100755 --- a/toolkit.py +++ b/toolkit.py @@ -598,10 +598,10 @@ class ClassificationProject(object): np.random.set_state(rn_state) w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label]) # yield them batch wise - for start in range(0, len(x_train[y_train==class_label]), self.batch_size): - yield (x_train[y_train==class_label][start:start+self.batch_size], - y_train[y_train==class_label][start:start+self.batch_size], - w_train[y_train==class_label][start:start+self.batch_size]) + for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)): + yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)], + y_train[y_train==class_label][start:start+int(self.batch_size/2)], + w_train[y_train==class_label][start:start+int(self.batch_size/2)]) # restart -- GitLab