From d23f0440cdbcbf80e16ea4de6e7459862a7804eb Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Thu, 9 Aug 2018 14:45:41 +0200 Subject: [PATCH] weighted accuracy --- toolkit.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/toolkit.py b/toolkit.py index 6f24378..3c2d3d9 100755 --- a/toolkit.py +++ b/toolkit.py @@ -605,7 +605,8 @@ class ClassificationProject(object): np.random.seed(self.random_seed) self._model.compile(optimizer=optimizer, loss=self.loss, - metrics=['accuracy']) + weighted_metrics=['accuracy'] + ) np.random.set_state(rn_state) if os.path.exists(os.path.join(self.project_dir, "weights.h5")): if self.is_training: @@ -1090,7 +1091,7 @@ class ClassificationProject(object): plt.clf() - def plot_accuracy(self, all_trainings=False, log=False): + def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"): """ Plot the value of the accuracy metric for each epoch @@ -1102,14 +1103,14 @@ class ClassificationProject(object): else: hist_dict = self.history.history - if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict): + if (not acc_suffix in hist_dict) or (not 'val_'+acc_suffix in hist_dict): logger.warning("No previous history found for plotting, try global history") hist_dict = self.csv_hist logger.info("Plot accuracy") - plt.plot(hist_dict['acc']) - plt.plot(hist_dict['val_acc']) + plt.plot(hist_dict[acc_suffix]) + plt.plot(hist_dict['val_'+acc_suffix]) plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') @@ -1122,11 +1123,11 @@ class ClassificationProject(object): def plot_all(self): self.plot_ROC() - self.plot_accuracy() + # self.plot_accuracy() self.plot_loss() self.plot_score() self.plot_weights() - self.plot_significance() + # self.plot_significance() def create_getter(dataset_name): @@ -1165,8 +1166,8 @@ if __name__ == "__main__": optimizer="Adam", #optimizer="SGD", #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), - earlystopping_opts=dict(monitor='val_loss', - min_delta=0, patience=2, verbose=0, mode='auto'), + earlystopping_opts=dict(monitor='val_loss', + min_delta=0, patience=2, verbose=0, mode='auto'), selection="1", branches = ["met", "mt"], weight_expr = "eventWeight*genWeight", -- GitLab