diff --git a/toolkit.py b/toolkit.py index 3c50858a06681071642d5119f8559c2c17281eff..646f09fc5337775274364158e0ddf2ad777e0de6 100755 --- a/toolkit.py +++ b/toolkit.py @@ -619,7 +619,8 @@ class ClassificationProject(object): np.random.seed(self.random_seed) self._model.compile(optimizer=optimizer, loss=self.loss, - metrics=['accuracy']) + weighted_metrics=['accuracy'] + ) np.random.set_state(rn_state) if os.path.exists(os.path.join(self.project_dir, "weights.h5")): @@ -1106,7 +1107,7 @@ class ClassificationProject(object): plt.clf() - def plot_accuracy(self, all_trainings=False, log=False): + def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"): """ Plot the value of the accuracy metric for each epoch @@ -1118,14 +1119,14 @@ class ClassificationProject(object): else: hist_dict = self.history.history - if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict): + if (not acc_suffix in hist_dict) or (not 'val_'+acc_suffix in hist_dict): logger.warning("No previous history found for plotting, try global history") hist_dict = self.csv_hist logger.info("Plot accuracy") - plt.plot(hist_dict['acc']) - plt.plot(hist_dict['val_acc']) + plt.plot(hist_dict[acc_suffix]) + plt.plot(hist_dict['val_'+acc_suffix]) plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') @@ -1138,11 +1139,11 @@ class ClassificationProject(object): def plot_all(self): self.plot_ROC() - self.plot_accuracy() + # self.plot_accuracy() self.plot_loss() self.plot_score() self.plot_weights() - self.plot_significance() + # self.plot_significance() def create_getter(dataset_name): @@ -1181,8 +1182,8 @@ if __name__ == "__main__": optimizer="Adam", #optimizer="SGD", #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), - earlystopping_opts=dict(monitor='val_loss', - min_delta=0, patience=2, verbose=0, mode='auto'), + earlystopping_opts=dict(monitor='val_loss', + min_delta=0, patience=2, verbose=0, mode='auto'), selection="1", branches = ["met", "mt"], weight_expr = "eventWeight*genWeight",