# Patch summary (toolkit.py): normalise event weights by the mean class-weighted
# training weight.
#
# Hunk 1 (ClassificationProject): adds the backing attribute
#   `_mean_train_weight`, initialised to None next to `_fields`.
# Hunk 2 (ClassificationProject): adds the `mean_train_weight` property. On
#   first access it computes np.mean over `w_train` multiplied element-wise by
#   the class weight selected per event (`class_weight` indexed by `y_train`
#   cast to int), stores the result in `_mean_train_weight`, and returns the
#   stored value on later accesses.
# Hunk 3 (ClassificationProject.w_validation): the class-weighted validation
#   weights are now returned divided by `mean_train_weight`. The division
#   creates a new array; the stored `_w_validation` itself is not rescaled.
# Hunk 4 (ClassificationProjectRNN): the per-batch weights yielded by the
#   batch generator are divided by `mean_train_weight` as well.
#
# NOTE(review): the mean is taken over all of `w_train`, while `w_validation`
#   is built from `w_train[split_index:]` only -- confirm that including the
#   validation part in the normalisation constant is intended.
# NOTE(review): no hunk shown here resets `_mean_train_weight`; if `w_train`,
#   `y_train` or `class_weight` can change after the first access the cached
#   value would be stale -- TODO confirm against the rest of the file.
diff --git a/toolkit.py b/toolkit.py index d596c9ad60e4db61d5037b549b1baf8052fc1e81..e424460e69de51bd585a21bede5d12fc6538fe8e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -338,6 +338,8 @@ class ClassificationProject(object): self._fields = None + self._mean_train_weight = None + @property def fields(self): @@ -741,6 +743,13 @@ class ClassificationProject(object): np.random.shuffle(self._scores_train) + @property + def mean_train_weight(self): + if self._mean_train_weight is None: + self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]) + return self._mean_train_weight + + @property def w_validation(self): "class weighted validation data weights" @@ -749,7 +758,7 @@ class ClassificationProject(object): self._w_validation = np.array(self.w_train[split_index:]) self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0] self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1] - return self._w_validation + return self._w_validation/self.mean_train_weight @property @@ -1501,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject): x_input = self.get_input_list(x_batch) yield (x_input, y_train[shuffled_idx[start:start+int(self.batch_size)]], - w_batch*np.array(self.class_weight)[y_batch.astype(int)]) + w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight) @property