From 7007d7c04804223759ef7bf23461b049496c49a6 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 29 Oct 2018 19:02:12 +0100
Subject: [PATCH] improved options for loss and ROC comparison plot

---
 compare.py |  7 ++++++-
 toolkit.py | 20 ++++++++++----------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/compare.py b/compare.py
index cfe29df..a399cc7 100755
--- a/compare.py
+++ b/compare.py
@@ -22,6 +22,7 @@ def overlay_ROC(filename, *projects, **kwargs):
     threshold_log = kwargs.pop("threshold_log", True)
     lumifactor = kwargs.pop("lumifactor", None)
     tight_layout = kwargs.pop("tight_layout", False)
+    show_auc = kwargs.pop("show_auc", True)
     if kwargs:
         raise KeyError("Unknown kwargs: {}".format(kwargs))
 
@@ -52,7 +53,11 @@ def overlay_ROC(filename, *projects, **kwargs):
             roc_auc = auc(tpr, fpr, reorder=True)
 
         ax.grid(color='gray', linestyle='--', linewidth=1)
-        ax.plot(tpr,  fpr, label=str(p.name+" (AUC = {:.3f})".format(roc_auc)), color=color)
+        if show_auc:
+            label = str(p.name+" (AUC = {:.3f})".format(roc_auc))
+        else:
+            label = p.name
+        ax.plot(tpr,  fpr, label=label, color=color)
         if plot_thresholds:
             ax2.plot(tpr, threshold, "--", color=color)
         if lumifactor is not None:
diff --git a/toolkit.py b/toolkit.py
index 6780435..3b51650 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1457,19 +1457,19 @@ class ClassificationProject(object):
             hist_dict = self.csv_hist
 
         logger.info("Plot losses")
-        plt.plot(hist_dict['loss'])
-        plt.plot(hist_dict['val_loss'])
-        plt.ylabel('loss')
-        plt.xlabel('epoch')
-        plt.legend(['training data','validation data'], loc='upper left')
+        fig, ax = plt.subplots()
+        ax.plot(hist_dict['loss'])
+        ax.plot(hist_dict['val_loss'])
+        ax.set_ylabel('loss')
+        ax.set_xlabel('epoch')
+        ax.legend(['training data','validation data'], loc='upper left')
         if log:
-            plt.yscale("log")
+            ax.set_yscale("log")
         if xlim is not None:
-            plt.xlim(*xlim)
+            ax.set_xlim(*xlim)
         if ylim is not None:
-            plt.ylim(*ylim)
-        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
-        plt.clf()
+            ax.set_ylim(*ylim)
+        return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf"))
 
 
     def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):
-- 
GitLab