diff --git a/toolkit.py b/toolkit.py index fb0a382341040d08e1de9d58c8882943ee2dd145..d0f44ce251f16a17a026f310bcc404ed54b23e9f 100755 --- a/toolkit.py +++ b/toolkit.py @@ -165,9 +165,11 @@ class ClassificationProject(object): class will be used in each epoch, but different subsets of the overrepresented class will be used in each epoch. - :param random_seed: use this seed value when initialising the model and produce consistent results. Note: - random data is also used for shuffling the training data, so results may vary still. To - produce consistent results, set the numpy random seed before training. + :param random_seed: use this seed value when initialising the model and produce consistent results. + + :param shuffle_seed: use this seed for shuffling the training data + the first time. This seed (increased by one) is used again before + training when keras shuffling is used. :param loss: loss function name @@ -238,6 +240,7 @@ class ClassificationProject(object): use_tensorboard=False, tensorboard_opts=None, random_seed=1234, + shuffle_seed=42, balance_dataset=False, loss='binary_crossentropy', mask_value=None, @@ -311,6 +314,7 @@ class ClassificationProject(object): if tensorboard_opts is not None: self.tensorboard_opts.update(**tensorboard_opts) self.random_seed = random_seed + self.shuffle_seed = shuffle_seed self.balance_dataset = balance_dataset self.loss = loss self.mask_value = mask_value @@ -789,6 +793,13 @@ class ClassificationProject(object): def shuffle_training_data(self): + + np.random.seed(self.shuffle_seed) + + # touch property to make sure it is created + # before shuffling the inputs! + self.w_train_tot + rn_state = np.random.get_state() np.random.shuffle(self.x_train) np.random.set_state(rn_state) @@ -811,6 +822,8 @@ class ClassificationProject(object): class_weight = self.class_weight else: class_weight = self.balanced_class_weight + if not self.data_loaded: + self._w_train_tot = None if self._w_train_tot is None: if self.apply_class_weight: self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] @@ -886,6 +899,7 @@ class ClassificationProject(object): if not self.balance_dataset: try: self.is_training = True + np.random.seed(self.shuffle_seed+1) # since we use keras shuffling here self.model.fit(self.x_train, # the reshape might be unnescessary here self.y_train.reshape(-1, 1),