From da18e449f25b2e6e0542c73190c8405a89222c36 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 17 May 2018 16:36:36 +0200
Subject: [PATCH] correct batch sizes for balanced training

---
 toolkit.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 3343700..8451292 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -598,10 +598,10 @@ class ClassificationProject(object):
             np.random.set_state(rn_state)
             w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
             # yield them batch wise
-            for start in range(0, len(x_train[y_train==class_label]), self.batch_size):
-                yield (x_train[y_train==class_label][start:start+self.batch_size],
-                       y_train[y_train==class_label][start:start+self.batch_size],
-                       w_train[y_train==class_label][start:start+self.batch_size])
+            for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
+                yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
+                       y_train[y_train==class_label][start:start+int(self.batch_size/2)],
+                       w_train[y_train==class_label][start:start+int(self.batch_size/2)])
             # restart
 
 
-- 
GitLab