From 0d931a419dd16e67596e35ebb20dc194141c2fd1 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Tue, 12 Jun 2018 09:45:28 +0200 Subject: [PATCH] shuffle training data before splitting in balance_dataset mode --- toolkit.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/toolkit.py b/toolkit.py index 82391fa..7f6cad0 100755 --- a/toolkit.py +++ b/toolkit.py @@ -605,7 +605,7 @@ class ClassificationProject(object): @property def w_validation(self): - "class weighted validation data" + "class weighted validation data weights" split_index = int((1-self.validation_split)*len(self.x_train)) if self._w_validation is None: self._w_validation = np.array(self.w_train[split_index:]) @@ -616,13 +616,14 @@ class ClassificationProject(object): @property def class_weighted_validation_data(self): + "class weighted validation data. Attention: Shuffle training data before using this!" split_index = int((1-self.validation_split)*len(self.x_train)) return self.x_train[split_index:], self.y_train[split_index:], self.w_validation @property def training_data(self): - "training data with validation data split off" + "training data with validation data split off. Attention: Shuffle training data before using this!" split_index = int((1-self.validation_split)*len(self.x_train)) return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index] @@ -686,6 +687,7 @@ class ClassificationProject(object): logger.info("Interrupt training - continue with rest") else: try: + self.shuffle_training_data() # needed here too, in order to get correct validation data self.is_training = True labels, label_counts = np.unique(self.y_train, return_counts=True) logger.info("Training on balanced batches") -- GitLab