From a741cd01a4882bd75f81485683b5afb255480690 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 29 Nov 2018 14:34:57 +0100
Subject: [PATCH] fixing loss plot

---
 toolkit.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index d16bb2a..b32e215 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1621,19 +1621,20 @@ class ClassificationProject(object):
             hist_dict = self.csv_hist
 
         logger.info("Plot losses")
-        plt.plot(hist_dict[loss_key])
-        plt.plot(hist_dict['val_'+loss_key])
-        plt.ylabel(loss_key)
-        plt.xlabel('epoch')
-        plt.legend(['training data','validation data'], loc='upper left')
+
+        fig, ax = plt.subplots()
+        ax.plot(hist_dict[loss_key])
+        ax.plot(hist_dict['val_'+loss_key])
+        ax.set_ylabel(loss_key)
+        ax.set_xlabel('epoch')
+        ax.legend(['training data','validation data'], loc='upper left')
         if log:
-            plt.yscale("log")
+            ax.set_yscale("log")
         if xlim is not None:
-            plt.xlim(*xlim)
+            ax.set_xlim(*xlim)
         if ylim is not None:
-            plt.ylim(*ylim)
-        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
-        plt.clf()
+            ax.set_ylim(*ylim)
+        return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf"))
 
 
     def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):
-- 
GitLab