From d23f0440cdbcbf80e16ea4de6e7459862a7804eb Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Thu, 9 Aug 2018 14:45:41 +0200
Subject: [PATCH] weighted accuracy

---
 toolkit.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 6f24378..3c2d3d9 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -605,7 +605,8 @@ class ClassificationProject(object):
             np.random.seed(self.random_seed)
             self._model.compile(optimizer=optimizer,
                                 loss=self.loss,
-                                metrics=['accuracy'])
+                                weighted_metrics=['accuracy']
+            )
             np.random.set_state(rn_state)
             if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
                 if self.is_training:
@@ -1090,7 +1091,7 @@ class ClassificationProject(object):
         plt.clf()
 
 
-    def plot_accuracy(self, all_trainings=False, log=False):
+    def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):
         """
         Plot the value of the accuracy metric for each epoch
 
@@ -1102,14 +1103,14 @@ class ClassificationProject(object):
         else:
             hist_dict = self.history.history
 
-        if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict):
+        if (not acc_suffix in hist_dict) or (not 'val_'+acc_suffix in hist_dict):
             logger.warning("No previous history found for plotting, try global history")
             hist_dict = self.csv_hist
 
         logger.info("Plot accuracy")
 
-        plt.plot(hist_dict['acc'])
-        plt.plot(hist_dict['val_acc'])
+        plt.plot(hist_dict[acc_suffix])
+        plt.plot(hist_dict['val_'+acc_suffix])
         plt.title('model accuracy')
         plt.ylabel('accuracy')
         plt.xlabel('epoch')
@@ -1122,11 +1123,11 @@ class ClassificationProject(object):
 
     def plot_all(self):
         self.plot_ROC()
-        self.plot_accuracy()
+        # self.plot_accuracy()
         self.plot_loss()
         self.plot_score()
         self.plot_weights()
-        self.plot_significance()
+        # self.plot_significance()
 
 
 def create_getter(dataset_name):
@@ -1165,8 +1166,8 @@ if __name__ == "__main__":
                               optimizer="Adam",
                               #optimizer="SGD",
                               #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
-                                earlystopping_opts=dict(monitor='val_loss',
-                                                        min_delta=0, patience=2, verbose=0, mode='auto'),
+                              earlystopping_opts=dict(monitor='val_loss',
+                                                      min_delta=0, patience=2, verbose=0, mode='auto'),
                               selection="1",
                               branches = ["met", "mt"],
                               weight_expr = "eventWeight*genWeight",
-- 
GitLab