diff --git a/toolkit.py b/toolkit.py index d16bb2aae49f11f9113d97b9bd6c22070db3f812..b32e2152f62020da0fe6fd12228779fea1ee41a1 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1621,19 +1621,20 @@ class ClassificationProject(object): hist_dict = self.csv_hist logger.info("Plot losses") - plt.plot(hist_dict[loss_key]) - plt.plot(hist_dict['val_'+loss_key]) - plt.ylabel(loss_key) - plt.xlabel('epoch') - plt.legend(['training data','validation data'], loc='upper left') + + fig, ax = plt.subplots() + ax.plot(hist_dict[loss_key]) + ax.plot(hist_dict['val_'+loss_key]) + ax.set_ylabel(loss_key) + ax.set_xlabel('epoch') + ax.legend(['training data','validation data'], loc='upper left') if log: - plt.yscale("log") + ax.set_yscale("log") if xlim is not None: - plt.xlim(*xlim) + ax.set_xlim(*xlim) if ylim is not None: - plt.ylim(*ylim) - plt.savefig(os.path.join(self.project_dir, "losses.pdf")) - plt.clf() + ax.set_ylim(*ylim) + return save_show(plt, fig, os.path.join(self.project_dir, "losses.pdf")) def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):