diff --git a/toolkit.py b/toolkit.py index 6e2b819af16776af41daf6af8c209bf6352e54b6..b0500fa5da7675bb53f16b23968ef124b62f230b 100755 --- a/toolkit.py +++ b/toolkit.py @@ -2027,10 +2027,12 @@ class ClassificationProjectDecorr(ClassificationProject): def _init_from_args(self, name, decorr_bins=10, + adv_lr=0.001, **kwargs): super(ClassificationProjectDecorr, self)._init_from_args(name, **kwargs) self.decorr_binnings = [] self.decorr_bins = decorr_bins + self.adv_lr = adv_lr self._write_info("project_type", "ClassificationProjectDecorr") self._class_layers = None self._adv_hidden_layers = None @@ -2247,7 +2249,7 @@ class ClassificationProjectDecorr(ClassificationProject): self.set_trainability(self.class_layers, False) self.set_trainability(self.adv_layers, True) self._model_adv.compile( - optimizer=keras.optimizers.adam(lr=0.001), + optimizer=keras.optimizers.adam(lr=self.adv_lr), loss="categorical_crossentropy", ) return self._model_adv