diff --git a/toolkit.py b/toolkit.py index e424460e69de51bd585a21bede5d12fc6538fe8e..e9128817576467c28eeafa78d264a3e28c1f7d9e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -828,7 +828,7 @@ class ClassificationProject(object): validation_split = self.validation_split, # we have to multiply by class weight since keras ignores class weight if sample weight is given # see https://github.com/keras-team/keras/issues/497 - sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)], + sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]/self.mean_train_weight, shuffle=True, batch_size=self.batch_size, callbacks=self.callbacks_list)