Skip to content
Snippets Groups Projects
Commit 5e3ae025 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Attempt to fix shuffling issues

parent aaaf823c
No related branches found
No related tags found
No related merge requests found
......@@ -165,9 +165,11 @@ class ClassificationProject(object):
class will be used in each epoch, but different subsets of the
overrepresented class will be used in each epoch.
:param random_seed: use this seed value when initialising the model and produce consistent results. Note:
random data is also used for shuffling the training data, so results may vary still. To
produce consistent results, set the numpy random seed before training.
:param random_seed: use this seed value when initialising the model and produce consistent results.
:param shuffle_seed: use this seed for shuffling the training data
the first time. This seed (increased by one) is used again before
training when keras shuffling is used.
:param loss: loss function name
......@@ -238,6 +240,7 @@ class ClassificationProject(object):
use_tensorboard=False,
tensorboard_opts=None,
random_seed=1234,
shuffle_seed=42,
balance_dataset=False,
loss='binary_crossentropy',
mask_value=None,
......@@ -311,6 +314,7 @@ class ClassificationProject(object):
if tensorboard_opts is not None:
self.tensorboard_opts.update(**tensorboard_opts)
self.random_seed = random_seed
self.shuffle_seed = shuffle_seed
self.balance_dataset = balance_dataset
self.loss = loss
self.mask_value = mask_value
......@@ -789,6 +793,13 @@ class ClassificationProject(object):
def shuffle_training_data(self):
np.random.seed(self.shuffle_seed)
# touch property to make sure it is created
# before shuffling the inputs!
self.w_train_tot
rn_state = np.random.get_state()
np.random.shuffle(self.x_train)
np.random.set_state(rn_state)
......@@ -811,6 +822,8 @@ class ClassificationProject(object):
class_weight = self.class_weight
else:
class_weight = self.balanced_class_weight
if not self.data_loaded:
self._w_train_tot = None
if self._w_train_tot is None:
if self.apply_class_weight:
self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
......@@ -886,6 +899,7 @@ class ClassificationProject(object):
if not self.balance_dataset:
try:
self.is_training = True
np.random.seed(self.shuffle_seed+1) # since we use keras shuffling here
self.model.fit(self.x_train,
# the reshape might be unnescessary here
self.y_train.reshape(-1, 1),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment