diff --git a/toolkit.py b/toolkit.py
index d16bb2aae49f11f9113d97b9bd6c22070db3f812..b32e2152f62020da0fe6fd12228779fea1ee41a1 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1621,19 +1621,20 @@ class ClassificationProject(object):
             hist_dict = self.csv_hist
 
         logger.info("Plot losses")
-        plt.plot(hist_dict[loss_key])
-        plt.plot(hist_dict['val_'+loss_key])
-        plt.ylabel(loss_key)
-        plt.xlabel('epoch')
-        plt.legend(['training data','validation data'], loc='upper left')
+
+        fig, ax = plt.subplots()
+        ax.plot(hist_dict[loss_key])
+        ax.plot(hist_dict['val_'+loss_key])
+        ax.set_ylabel(loss_key)
+        ax.set_xlabel('epoch')
+        ax.legend(['training data','validation data'], loc='upper left')
         if log:
-            plt.yscale("log")
+            ax.set_yscale("log")
         if xlim is not None:
-            plt.xlim(*xlim)
+            ax.set_xlim(*xlim)
         if ylim is not None:
-            plt.ylim(*ylim)
-        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
-        plt.clf()
+            ax.set_ylim(*ylim)
+        return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf"))
 
 
     def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):