diff --git a/toolkit.py b/toolkit.py
index 3bd7d127a23d0d5039c75fce86cf78816a894067..6f2d075acc8110b723bf37565ce39538e0bcdede 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -760,26 +760,28 @@ class ClassificationProject(object):
         return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
 
 
-    def yield_batch(self, class_label):
+    def yield_single_class_batch(self, class_label):
+        """
+        Generate batches of half batch size, containing only entries for the given class label.
+        The weights are multiplied by balanced_class_weight.
+        """
+        x_train, y_train, w_train = self.training_data
+        class_idx = np.where(y_train==class_label)[0]
         while True:
-            x_train, y_train, w_train = self.training_data
-            # shuffle the entries for this class label
-            rn_state = np.random.get_state()
-            x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
-            np.random.set_state(rn_state)
-            w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
+            # shuffle the indices for this class label
+            shuffled_idx = np.random.permutation(class_idx)
             # yield them batch wise
-            for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
-                yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
-                       y_train[y_train==class_label][start:start+int(self.batch_size/2)],
-                       w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label])
-            # restart
+            for start in range(0, len(shuffled_idx), int(self.batch_size/2)):
+                yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]],
+                       y_train[shuffled_idx[start:start+int(self.batch_size/2)]],
+                       w_train[shuffled_idx[start:start+int(self.batch_size/2)]]*self.balanced_class_weight[class_label])
 
 
     def yield_balanced_batch(self):
         "generate batches with equal amounts of both classes"
         logcounter = 0
-        for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)):
+        for batch_0, batch_1 in izip(self.yield_single_class_batch(0),
+                                     self.yield_single_class_batch(1)):
             if logcounter == 10:
                 logger.debug("\rSumw sig*balanced_class_weight[1]: {}".format(np.sum(batch_1[2])))
                 logger.debug("\rSumw bkg*balanced_class_weight[0]: {}".format(np.sum(batch_0[2])))