diff --git a/toolkit.py b/toolkit.py index 598d60e18cc991759288994462c5ee603406d769..3ea752880f670e44336fdbfd6938190bfe61c9d9 100755 --- a/toolkit.py +++ b/toolkit.py @@ -811,8 +811,9 @@ class ClassificationProject(object): self.y_train.reshape(-1, 1), epochs=epochs, validation_split = self.validation_split, - class_weight=self.class_weight, - sample_weight=self.w_train, + # we have to multiply by class weight since keras ignores class weight if sample weight is given + # see https://github.com/keras-team/keras/issues/497 + sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)], shuffle=True, batch_size=self.batch_size, callbacks=self.callbacks_list)