diff --git a/compare.py b/compare.py index cfe29df3fbd35fd8fa19edc6f2e8cc5c5ce886d4..a399cc7a11f158555719e60db80b880efe2f037a 100755 --- a/compare.py +++ b/compare.py @@ -22,6 +22,7 @@ def overlay_ROC(filename, *projects, **kwargs): threshold_log = kwargs.pop("threshold_log", True) lumifactor = kwargs.pop("lumifactor", None) tight_layout = kwargs.pop("tight_layout", False) + show_auc = kwargs.pop("show_auc", True) if kwargs: raise KeyError("Unknown kwargs: {}".format(kwargs)) @@ -52,7 +53,11 @@ def overlay_ROC(filename, *projects, **kwargs): roc_auc = auc(tpr, fpr, reorder=True) ax.grid(color='gray', linestyle='--', linewidth=1) - ax.plot(tpr, fpr, label=str(p.name+" (AUC = {:.3f})".format(roc_auc)), color=color) + if show_auc: + label = str(p.name+" (AUC = {:.3f})".format(roc_auc)) + else: + label = p.name + ax.plot(tpr, fpr, label=label, color=color) if plot_thresholds: ax2.plot(tpr, threshold, "--", color=color) if lumifactor is not None: diff --git a/toolkit.py b/toolkit.py index 67804358efbb463f23134d8bafa5f2205b01cb1a..3b5165000bc5ad6c6932b1dfc6ccece2e7a97666 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1457,19 +1457,19 @@ class ClassificationProject(object): hist_dict = self.csv_hist logger.info("Plot losses") - plt.plot(hist_dict['loss']) - plt.plot(hist_dict['val_loss']) - plt.ylabel('loss') - plt.xlabel('epoch') - plt.legend(['training data','validation data'], loc='upper left') + fig, ax = plt.subplots() + ax.plot(hist_dict['loss']) + ax.plot(hist_dict['val_loss']) + ax.set_ylabel('loss') + ax.set_xlabel('epoch') + ax.legend(['training data','validation data'], loc='upper left') if log: - plt.yscale("log") + ax.set_yscale("log") if xlim is not None: - plt.xlim(*xlim) + ax.set_xlim(*xlim) if ylim is not None: - plt.ylim(*ylim) - plt.savefig(os.path.join(self.project_dir, "losses.pdf")) - plt.clf() + ax.set_ylim(*ylim) + return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf")) def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):