From 3235c3eae40070d2c51f2ba258f65ae3ac9c5600 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 17 May 2018 17:11:10 +0200
Subject: [PATCH] log scale optional for plots

---
 toolkit.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 8451292..3c65869 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -822,7 +822,7 @@ class ClassificationProject(object):
         plt.clf()
 
 
-    def plot_score(self):
+    def plot_score(self, log=True):
         plot_opts = dict(bins=50, range=(0, 1))
         centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
         centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
@@ -834,9 +834,10 @@ class ClassificationProject(object):
         ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
         ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test")
         ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test")
-        ax.set_yscale("log")
+        if log:
+            ax.set_yscale("log")
         ax.set_xlabel("NN output")
-        plt.legend(loc='upper center', framealpha=0.5)
+        fig.legend(loc='upper center', framealpha=0.5)
         fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
 
 
@@ -851,7 +852,7 @@ class ClassificationProject(object):
             hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
         return hist_dict
 
-    def plot_loss(self, all_trainings=False):
+    def plot_loss(self, all_trainings=False, log=False):
         """
         Plot the value of the loss function for each epoch
 
@@ -873,11 +874,13 @@ class ClassificationProject(object):
         plt.ylabel('loss')
         plt.xlabel('epoch')
         plt.legend(['train','test'], loc='upper left')
+        if log:
+            plt.yscale("log")
         plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
         plt.clf()
 
 
-    def plot_accuracy(self, all_trainings=False):
+    def plot_accuracy(self, all_trainings=False, log=False):
         """
         Plot the value of the accuracy metric for each epoch
 
@@ -894,12 +897,15 @@ class ClassificationProject(object):
             hist_dict = self.csv_hist
 
         logger.info("Plot accuracy")
+
         plt.plot(hist_dict['acc'])
         plt.plot(hist_dict['val_acc'])
         plt.title('model accuracy')
         plt.ylabel('accuracy')
         plt.xlabel('epoch')
         plt.legend(['train', 'test'], loc='upper left')
+        if log:
+            plt.yscale("log")
         plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
         plt.clf()
 
-- 
GitLab