diff --git a/compare.py b/compare.py
index cfe29df3fbd35fd8fa19edc6f2e8cc5c5ce886d4..a399cc7a11f158555719e60db80b880efe2f037a 100755
--- a/compare.py
+++ b/compare.py
@@ -22,6 +22,7 @@ def overlay_ROC(filename, *projects, **kwargs):
     threshold_log = kwargs.pop("threshold_log", True)
     lumifactor = kwargs.pop("lumifactor", None)
     tight_layout = kwargs.pop("tight_layout", False)
+    show_auc = kwargs.pop("show_auc", True)
     if kwargs:
         raise KeyError("Unknown kwargs: {}".format(kwargs))
 
@@ -52,7 +53,11 @@ def overlay_ROC(filename, *projects, **kwargs):
             roc_auc = auc(tpr, fpr, reorder=True)
 
         ax.grid(color='gray', linestyle='--', linewidth=1)
-        ax.plot(tpr,  fpr, label=str(p.name+" (AUC = {:.3f})".format(roc_auc)), color=color)
+        if show_auc:
+            label = str(p.name+" (AUC = {:.3f})".format(roc_auc))
+        else:
+            label = p.name
+        ax.plot(tpr,  fpr, label=label, color=color)
         if plot_thresholds:
             ax2.plot(tpr, threshold, "--", color=color)
         if lumifactor is not None:
diff --git a/toolkit.py b/toolkit.py
index 67804358efbb463f23134d8bafa5f2205b01cb1a..3b5165000bc5ad6c6932b1dfc6ccece2e7a97666 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1457,19 +1457,19 @@ class ClassificationProject(object):
             hist_dict = self.csv_hist
 
         logger.info("Plot losses")
-        plt.plot(hist_dict['loss'])
-        plt.plot(hist_dict['val_loss'])
-        plt.ylabel('loss')
-        plt.xlabel('epoch')
-        plt.legend(['training data','validation data'], loc='upper left')
+        fig, ax = plt.subplots()
+        ax.plot(hist_dict['loss'])
+        ax.plot(hist_dict['val_loss'])
+        ax.set_ylabel('loss')
+        ax.set_xlabel('epoch')
+        ax.legend(['training data','validation data'], loc='upper left')
         if log:
-            plt.yscale("log")
+            ax.set_yscale("log")
         if xlim is not None:
-            plt.xlim(*xlim)
+            ax.set_xlim(*xlim)
         if ylim is not None:
-            plt.ylim(*ylim)
-        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
-        plt.clf()
+            ax.set_ylim(*ylim)
+        return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf"))
 
 
     def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):