diff --git a/toolkit.py b/toolkit.py
index 6098d68219839ddc61b6f4dadcbebbaa0c8e13b1..761dd7a6dd03e39b1abb265e21c695a32a4b9f70 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -20,6 +20,7 @@ import glob
 import shutil
 import gc
 import random
+from tqdm import tqdm
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -816,7 +817,7 @@ class ClassificationProject(object):
         return self._model
 
 
-    def _compile_or_load_model(self):
+    def _compile_model(self):
         logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
         Optimizer = getattr(keras.optimizers, self.optimizer)
         optimizer = Optimizer(**self.optimizer_opts)
@@ -830,6 +831,11 @@ class ClassificationProject(object):
         )
         np.random.set_state(rn_state)
 
+
+    def _compile_or_load_model(self):
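+        """Compile the model, then load previously trained weights if present."""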
+
+        self._compile_model()
+
         if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
             if self.is_training:
                 continue_training = self.query_yn("Found previously trained weights - "
@@ -2026,6 +2032,9 @@ class ClassificationProjectDecorr(ClassificationProject):
         self._class_layers = None
         self._adv_hidden_layers = None
         self._adv_target_layers = None
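+        # caches for the lazily built model components below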
+        self._class_output = None
+        self._adv_outputs = None
+        self._model_adv = None
 
 
     def load(self, *args, **kwargs):
@@ -2059,14 +2068,19 @@ class ClassificationProjectDecorr(ClassificationProject):
 
     def get_weight_list(self, w, y):
         w_list = super(ClassificationProjectDecorr, self).get_weight_list(w)
-        # copy first entry (the others might be references)
-        w_list[0] = np.array(w_list[0])
-        for w in w_list[1:]:
+        for i in range(1, len(w_list)):
+            # copy entry (it might be a reference to another entry)
+            w_list[i] = np.array(w_list[i])
             # set signal weights to 0 for decorr target
-            w[y[:,0]==1] = 0.
+            w_list[i][y[:,0]==1] = 0.
         return w_list
 
 
+    def transform_target(self, y):
+        # the parent class transformation is not needed for the decorr targets
+        return y
+
+
     @property
     def class_layers(self):
         """
@@ -2116,12 +2130,121 @@ class ClassificationProjectDecorr(ClassificationProject):
 
     @property
     def class_input(self):
-        pass
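+        """
+        Input tensor of the classification network.
+        """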
+        return self.class_layers[0]
+
+
+    @property
+    def class_output(self):
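+        """
+        Output tensor of the classification network, built by applying the
+        entries of class_layers in sequence (the first entry is the input
+        tensor).
+        """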
+        if self._class_output is None:
+            out = None
+            for layer in self.class_layers:
+                if out is None:
+                    out = layer
+                else:
+                    out = layer(out)
+            self._class_output = out
+        return self._class_output
+
+
+    @property
+    def adv_outputs(self):
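+        """
+        Output tensors of the adversary, one per decorrelation target,
+        built on top of the classification output.
+        """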
+        if self._adv_outputs is None:
+            out = self.class_output
+            for layer in self.adv_hidden_layers:
+                out = layer(out)
+            outs = []
+            for layer in self.adv_target_layers:
+                outs.append(layer(out))
+            self._adv_outputs = outs
+        return self._adv_outputs
+
+
+    def _compile_model(self):
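+        """
+        Compile the combined model with one loss per output: binary
+        crossentropy for the classification output and categorical
+        crossentropy for each adversarial output.
+        """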
+        logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
+        Optimizer = getattr(keras.optimizers, self.optimizer)
+        optimizer = Optimizer(**self.optimizer_opts)
+        logger.info("Compile model")
+        rn_state = np.random.get_state()
+        np.random.seed(self.random_seed)
+        self._model.compile(
+            optimizer=optimizer,
+            loss=["binary_crossentropy"]+["categorical_crossentropy"]*len(self.adv_outputs),
+            # define the "lambda" from arXiv:1703.03507 via the loss weights
+            loss_weights=self.loss_weights,
+            weighted_metrics=['accuracy']
+        )
+        np.random.set_state(rn_state)
+
+
+    @staticmethod
+    def set_trainability(layers, trainable):
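+        """Set the trainable flag on each of the given layers."""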
+        for layer in layers:
+            layer.trainable = trainable
 
 
     @property
     def model(self):
-        pass
+        """
+        This is the classification model + penalty term from the adversary
+        """
+        if self._model is None:
+            self._model = Model(
+                inputs=[self.class_input],
+                outputs=[self.class_output]+self.adv_outputs
+            )
+            # classification model only adjusts classification weights
+            self.set_trainability(self.class_layers, True)
+            self.set_trainability(self.adv_layers, False)
+            self._compile_or_load_model()
+        return self._model
+
+
+    @property
+    def model_adv(self):
+        """
+        Adversarial model that tries to reconstruct, from the classification
+        output, the distributions that should be decorrelated from it.
+        """
+        if self._model_adv is None:
+            self._model_adv = Model(
+                inputs=[self.class_input],
+                outputs=self.adv_outputs
+            )
+            # adversarial model only adjusts adversarial weights
+            self.set_trainability(self.class_layers, False)
+            self.set_trainability(self.adv_layers, True)
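+            # the trainable flags are baked into each model when it is
+            # compiled, so model and model_adv can share layers and still
+            # update disjoint sets of weights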
+            self._model_adv.compile(
+                optimizer=keras.optimizers.Adam(lr=0.01),
+                loss="categorical_crossentropy",
+            )
+        return self._model_adv
+
+
+    def train(self, epochs=10):
+        batch_generator = self.yield_batch()
+        metric_list = []
+        for epoch in range(epochs):
+            logger.info("Fitting epoch {}".format(epoch))
+            metrics = None
+            for batch_id in tqdm(range(self.steps_per_epoch)):
+                x, y, w = next(batch_generator)
+                # fit the classifier (the adversarial losses act as a penalty)
+                batch_metrics = self.model.train_on_batch(
+                    x, y, sample_weight=w
+                )
+                # fit the adversary on the decorrelation targets only - the
+                # first entries of y and w belong to the classifier
+                self.model_adv.train_on_batch(
+                    x, y[1:], sample_weight=w[1:]
+                )
+                batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics))
+                if metrics is None:
+                    metrics = batch_metrics
+                else:
+                    metrics = np.concatenate([metrics, batch_metrics])
+            # average the per-batch metrics over the epoch
+            metric_list.append(np.mean(metrics, axis=0))
+        return metric_list
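+
+    # Hypothetical usage sketch (arguments elided; see ClassificationProject
+    # for the constructor interface):
+    #
+    #     project = ClassificationProjectDecorr(...)
+    #     epoch_metrics = project.train(epochs=10)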
 
 
 if __name__ == "__main__":