diff --git a/toolkit.py b/toolkit.py
index 82391fa1bb60131cb70682d2a7453616de37ff6e..7f6cad06fcbc3af2f3504e5da1913511f2d90475 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -605,7 +605,7 @@ class ClassificationProject(object):
 
     @property
     def w_validation(self):
-        "class weighted validation data"
+        "class weighted validation data weights"
         split_index = int((1-self.validation_split)*len(self.x_train))
         if self._w_validation is None:
             self._w_validation = np.array(self.w_train[split_index:])
@@ -616,13 +616,14 @@ class ClassificationProject(object):
 
     @property
     def class_weighted_validation_data(self):
+        "class weighted validation data. Attention: Shuffle training data before using this!"
         split_index = int((1-self.validation_split)*len(self.x_train))
         return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
 
 
     @property
     def training_data(self):
-        "training data with validation data split off"
+        "training data with validation data split off. Attention: Shuffle training data before using this!"
         split_index = int((1-self.validation_split)*len(self.x_train))
         return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
 
@@ -686,6 +687,7 @@ class ClassificationProject(object):
                 logger.info("Interrupt training - continue with rest")
         else:
             try:
+                self.shuffle_training_data() # needed here too, in order to get correct validation data
                 self.is_training = True
                 labels, label_counts = np.unique(self.y_train, return_counts=True)
                 logger.info("Training on balanced batches")
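
Note (not part of the patch): the properties touched above take the validation set as the tail slice of x_train, which is why the added docstrings warn to shuffle the training data first. The sketch below illustrates that slicing under the same assumptions (NumPy arrays, a weight array alongside x/y); the standalone function and its name are hypothetical, only the split logic mirrors the patched properties.

# Sketch of the tail-slice validation split used by training_data /
# class_weighted_validation_data; shuffling first keeps the validation
# slice from being just the (possibly ordered) end of the arrays.
import numpy as np

def split_off_validation(x_train, y_train, w_train, validation_split=0.2, shuffle=True):
    if shuffle:
        idx = np.random.permutation(len(x_train))
        x_train, y_train, w_train = x_train[idx], y_train[idx], w_train[idx]
    split_index = int((1 - validation_split) * len(x_train))
    training = (x_train[:split_index], y_train[:split_index], w_train[:split_index])
    validation = (x_train[split_index:], y_train[split_index:], w_train[split_index:])
    return training, validation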