Skip to content
Snippets Groups Projects
Commit 6a2f9292 authored by Nikolai's avatar Nikolai
Browse files

try normalising weights

parent 46cb66fa
No related branches found
No related tags found
No related merge requests found
...@@ -338,6 +338,8 @@ class ClassificationProject(object): ...@@ -338,6 +338,8 @@ class ClassificationProject(object):
self._fields = None self._fields = None
self._mean_train_weight = None
@property @property
def fields(self): def fields(self):
...@@ -741,6 +743,13 @@ class ClassificationProject(object): ...@@ -741,6 +743,13 @@ class ClassificationProject(object):
np.random.shuffle(self._scores_train) np.random.shuffle(self._scores_train)
@property
def mean_train_weight(self):
if self._mean_train_weight is None:
self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)])
return self._mean_train_weight
@property @property
def w_validation(self): def w_validation(self):
"class weighted validation data weights" "class weighted validation data weights"
...@@ -749,7 +758,7 @@ class ClassificationProject(object): ...@@ -749,7 +758,7 @@ class ClassificationProject(object):
self._w_validation = np.array(self.w_train[split_index:]) self._w_validation = np.array(self.w_train[split_index:])
self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0] self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1] self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
return self._w_validation return self._w_validation/self.mean_train_weight
@property @property
...@@ -1501,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1501,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject):
x_input = self.get_input_list(x_batch) x_input = self.get_input_list(x_batch)
yield (x_input, yield (x_input,
y_train[shuffled_idx[start:start+int(self.batch_size)]], y_train[shuffled_idx[start:start+int(self.batch_size)]],
w_batch*np.array(self.class_weight)[y_batch.astype(int)]) w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight)
@property @property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment