Skip to content
Snippets Groups Projects
Commit 6a2f9292 authored by Nikolai's avatar Nikolai
Browse files

try normalising weights

parent 46cb66fa
No related branches found
No related tags found
No related merge requests found
......@@ -338,6 +338,8 @@ class ClassificationProject(object):
self._fields = None
self._mean_train_weight = None
@property
def fields(self):
......@@ -741,6 +743,13 @@ class ClassificationProject(object):
np.random.shuffle(self._scores_train)
@property
def mean_train_weight(self):
if self._mean_train_weight is None:
self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)])
return self._mean_train_weight
@property
def w_validation(self):
"class weighted validation data weights"
......@@ -749,7 +758,7 @@ class ClassificationProject(object):
self._w_validation = np.array(self.w_train[split_index:])
self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
return self._w_validation
return self._w_validation/self.mean_train_weight
@property
......@@ -1501,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject):
x_input = self.get_input_list(x_batch)
yield (x_input,
y_train[shuffled_idx[start:start+int(self.batch_size)]],
w_batch*np.array(self.class_weight)[y_batch.astype(int)])
w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight)
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment