From 0d931a419dd16e67596e35ebb20dc194141c2fd1 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 12 Jun 2018 09:45:28 +0200
Subject: [PATCH] shuffle training data before splitting in balance_dataset
 mode

---
 toolkit.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 82391fa..7f6cad0 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -605,7 +605,7 @@ class ClassificationProject(object):
 
     @property
     def w_validation(self):
-        "class weighted validation data"
+        "weights of the class weighted validation data"
         split_index = int((1-self.validation_split)*len(self.x_train))
         if self._w_validation is None:
             self._w_validation = np.array(self.w_train[split_index:])
@@ -616,13 +616,14 @@ class ClassificationProject(object):
 
     @property
     def class_weighted_validation_data(self):
+        "class weighted validation data (x, y, w), sliced off the end of the training arrays. Attention: shuffle the training data before using this!"
         split_index = int((1-self.validation_split)*len(self.x_train))
         return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
 
 
     @property
     def training_data(self):
-        "training data with validation data split off"
+        "training data (x, y, w) with the validation data split off from the end. Attention: shuffle the training data before using this!"
         split_index = int((1-self.validation_split)*len(self.x_train))
         return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
 
@@ -686,6 +687,7 @@ class ClassificationProject(object):
                 logger.info("Interrupt training - continue with rest")
         else:
             try:
+                self.shuffle_training_data()  # needed here too: validation data is the tail slice of the training arrays, so shuffle before splitting
                 self.is_training = True
                 labels, label_counts = np.unique(self.y_train, return_counts=True)
                 logger.info("Training on balanced batches")
-- 
GitLab