diff --git a/toolkit.py b/toolkit.py
index 6098d68219839ddc61b6f4dadcbebbaa0c8e13b1..761dd7a6dd03e39b1abb265e21c695a32a4b9f70 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -20,6 +20,7 @@ import glob
 import shutil
 import gc
 import random
+from tqdm import tqdm
 
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -816,7 +817,7 @@ class ClassificationProject(object):
         return self._model
 
 
-    def _compile_or_load_model(self):
+    def _compile_model(self):
         logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
         Optimizer = getattr(keras.optimizers, self.optimizer)
         optimizer = Optimizer(**self.optimizer_opts)
@@ -830,6 +831,11 @@ class ClassificationProject(object):
         )
         np.random.set_state(rn_state)
 
+
+    def _compile_or_load_model(self):
+
+        self._compile_model()
+
         if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
             if self.is_training:
                 continue_training = self.query_yn("Found previously trained weights - "
@@ -2026,6 +2032,9 @@ class ClassificationProjectDecorr(ClassificationProject):
         self._class_layers = None
         self._adv_hidden_layers = None
         self._adv_target_layers = None
+        self._class_output = None
+        self._adv_outputs = None
+        self._model_adv = None
 
 
     def load(self, *args, **kwargs):
@@ -2059,14 +2068,19 @@ class ClassificationProjectDecorr(ClassificationProject):
 
     def get_weight_list(self, w, y):
         w_list = super(ClassificationProjectDecorr, self).get_weight_list(w)
-        # copy first entry (the others might be references)
-        w_list[0] = np.array(w_list[0])
-        for w in w_list[1:]:
+        for i in range(1, len(w_list)):
+            # copy the entry first (it might be a reference)
+            w_list[i] = np.array(w_list[i])
             # set signal weights to 0 for decorr target
-            w[y[:,0]==1] = 0.
+            w_list[i][y[:,0]==1] = 0.
         return w_list
 
 
+    def transform_target(self, y):
+        # not needed here
+        return y
+
+
     @property
     def class_layers(self):
         """
@@ -2116,12 +2130,121 @@ class ClassificationProjectDecorr(ClassificationProject):
 
     @property
     def class_input(self):
-        pass
+        return self.class_layers[0]
+
+
+    @property
+    def class_output(self):
+        if self._class_output is None:
+            out = None
+            for layer in self.class_layers:
+                if out is None:
+                    out = layer
+                else:
+                    out = layer(out)
+            self._class_output = out
+        return self._class_output
+
+
+    @property
+    def adv_outputs(self):
+        if self._adv_outputs is None:
+            out = self.class_output
+            for layer in self._adv_hidden_layers:
+                out = layer(out)
+            outs = []
+            for layer in self._adv_target_layers:
+                outs.append(layer(out))
+            self._adv_outputs = outs
+        return self._adv_outputs
+
+
+    def _compile_model(self):
+        logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
+        Optimizer = getattr(keras.optimizers, self.optimizer)
+        optimizer = Optimizer(**self.optimizer_opts)
+        logger.info("Compile model")
+        rn_state = np.random.get_state()
+        np.random.seed(self.random_seed)
+        self._model.compile(
+            optimizer=optimizer,
+            loss=["binary_crossentropy"]+["categorical_crossentropy"]*len(self.adv_outputs),
+            # define the "lambda" from arXiv:1703.03507 via the loss weights
+            loss_weights=self.loss_weights,
+            weighted_metrics=['accuracy']
+        )
+        np.random.set_state(rn_state)
+
+
+    @staticmethod
+    def set_trainability(layers, trainable):
+        for layer in layers:
+            layer.trainable = trainable
 
 
     @property
     def model(self):
-        pass
+        """
+        This is the classification model + penalty term from the adversary
+        """
+        if self._model is None:
+            self._model = Model(
+                inputs=[self.class_input],
+                outputs=[self.class_output]+self.adv_outputs
+            )
+            # classification model only adjusts classification weights
+            self.set_trainability(self.class_layers, True)
+            self.set_trainability(self.adv_layers, False)
+            self._compile_or_load_model()
+        return self._model
+
+
+    @property
+    def model_adv(self):
+        """
+        Adversarial model that tries to reconstruct distributions that should
+        be uncorrelated from the classification output.
+        """
+        if self._model_adv is None:
+            self._model_adv = Model(
+                inputs=[self.class_input],
+                outputs=self.adv_outputs
+            )
+            # adversarial model only adjusts adversarial weights
+            self.set_trainability(self.class_layers, False)
+            self.set_trainability(self.adv_layers, True)
+            self._model_adv.compile(
+                optimizer=keras.optimizers.Adam(lr=0.01),
+                loss="categorical_crossentropy",
+            )
+        return self._model_adv
+
+
+    def train(self, epochs=10):
+        batch_generator = self.yield_batch()
+        metric_list = []
+        for epoch in range(epochs):
+            logger.info("Fitting epoch {}".format(epoch))
+            metrics = None
+            avg_metrics = None
+            for batch_id in tqdm(range(self.steps_per_epoch)):
+                x, y, w = next(batch_generator)
+                # fit the classifier
+                batch_metrics = self.model.train_on_batch(
+                    x, y, sample_weight=w
+                )
+                # fit the adversary
+                self.model_adv.train_on_batch(
+                    x, y[1:], sample_weight=w[1:]
+                )
+                batch_metrics = np.array(batch_metrics).reshape(1, len(batch_metrics))
+                if metrics is None:
+                    metrics = batch_metrics
+                else:
+                    metrics = np.concatenate([metrics, batch_metrics])
+            avg_metrics = np.mean(metrics, axis=0)
+            metric_list.append(avg_metrics)
+        return metric_list
 
 
 if __name__ == "__main__":
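A minimal usage sketch for the alternating training loop introduced in this patch (not part of the diff itself): the constructor argument below is a placeholder, since `ClassificationProjectDecorr.__init__` is defined elsewhere in toolkit.py.

```python
# Hypothetical driver for the adversarial decorrelation training added above.
# "my_project" and any further constructor options are assumptions.
from toolkit import ClassificationProjectDecorr

project = ClassificationProjectDecorr("my_project")  # assumed signature

# model and model_adv are built lazily on first access; train() then performs,
# per batch, one classifier update (adversary layers frozen) followed by one
# adversary update (classifier layers frozen), following arXiv:1703.03507.
metrics_per_epoch = project.train(epochs=10)
```

Because the two models share the same underlying layers and Keras fixes trainability at compile time, each `train_on_batch` call only updates the weights that were trainable when the corresponding model was compiled.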