From a741cd01a4882bd75f81485683b5afb255480690 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Thu, 29 Nov 2018 14:34:57 +0100 Subject: [PATCH] fixing loss plot --- toolkit.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/toolkit.py b/toolkit.py index d16bb2a..b32e215 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1621,19 +1621,20 @@ class ClassificationProject(object): hist_dict = self.csv_hist logger.info("Plot losses") - plt.plot(hist_dict[loss_key]) - plt.plot(hist_dict['val_'+loss_key]) - plt.ylabel(loss_key) - plt.xlabel('epoch') - plt.legend(['training data','validation data'], loc='upper left') + + fig, ax = plt.subplots() + ax.plot(hist_dict[loss_key]) + ax.plot(hist_dict['val_'+loss_key]) + ax.set_ylabel(loss_key) + ax.set_xlabel('epoch') + ax.legend(['training data','validation data'], loc='upper left') if log: - plt.yscale("log") + ax.set_yscale("log") if xlim is not None: - plt.xlim(*xlim) + ax.set_xlim(*xlim) if ylim is not None: - plt.ylim(*ylim) - plt.savefig(os.path.join(self.project_dir, "losses.pdf")) - plt.clf() + ax.set_ylim(*ylim) + return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf")) def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"): -- GitLab