Skip to content
Snippets Groups Projects
Commit 0d931a41 authored by Nikolai's avatar Nikolai
Browse files

shuffle training data before splitting in balance_dataset mode

parent 58e11537
No related branches found
No related tags found
No related merge requests found
...@@ -605,7 +605,7 @@ class ClassificationProject(object): ...@@ -605,7 +605,7 @@ class ClassificationProject(object):
@property @property
def w_validation(self): def w_validation(self):
"class weighted validation data" "class weighted validation data weights"
split_index = int((1-self.validation_split)*len(self.x_train)) split_index = int((1-self.validation_split)*len(self.x_train))
if self._w_validation is None: if self._w_validation is None:
self._w_validation = np.array(self.w_train[split_index:]) self._w_validation = np.array(self.w_train[split_index:])
...@@ -616,13 +616,14 @@ class ClassificationProject(object): ...@@ -616,13 +616,14 @@ class ClassificationProject(object):
@property @property
def class_weighted_validation_data(self): def class_weighted_validation_data(self):
"class weighted validation data. Attention: Shuffle training data before using this!"
split_index = int((1-self.validation_split)*len(self.x_train)) split_index = int((1-self.validation_split)*len(self.x_train))
return self.x_train[split_index:], self.y_train[split_index:], self.w_validation return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
@property @property
def training_data(self): def training_data(self):
"training data with validation data split off" "training data with validation data split off. Attention: Shuffle training data before using this!"
split_index = int((1-self.validation_split)*len(self.x_train)) split_index = int((1-self.validation_split)*len(self.x_train))
return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index] return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
...@@ -686,6 +687,7 @@ class ClassificationProject(object): ...@@ -686,6 +687,7 @@ class ClassificationProject(object):
logger.info("Interrupt training - continue with rest") logger.info("Interrupt training - continue with rest")
else: else:
try: try:
self.shuffle_training_data() # needed here too, in order to get correct validation data
self.is_training = True self.is_training = True
labels, label_counts = np.unique(self.y_train, return_counts=True) labels, label_counts = np.unique(self.y_train, return_counts=True)
logger.info("Training on balanced batches") logger.info("Training on balanced batches")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment