From 3235c3eae40070d2c51f2ba258f65ae3ac9c5600 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Thu, 17 May 2018 17:11:10 +0200 Subject: [PATCH] log scale optional for plots --- toolkit.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/toolkit.py b/toolkit.py index 8451292..3c65869 100755 --- a/toolkit.py +++ b/toolkit.py @@ -822,7 +822,7 @@ class ClassificationProject(object): plt.clf() - def plot_score(self): + def plot_score(self, log=True): plot_opts = dict(bins=50, range=(0, 1)) centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts) centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts) @@ -834,9 +834,10 @@ class ClassificationProject(object): ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train") ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test") ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test") - ax.set_yscale("log") + if log: + ax.set_yscale("log") ax.set_xlabel("NN output") - plt.legend(loc='upper center', framealpha=0.5) + fig.legend(loc='upper center', framealpha=0.5) fig.savefig(os.path.join(self.project_dir, "scores.pdf")) @@ -851,7 +852,7 @@ class ClassificationProject(object): hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]] return hist_dict - def plot_loss(self, all_trainings=False): + def plot_loss(self, all_trainings=False, log=False): """ Plot the value of the loss function for each epoch @@ -873,11 +874,13 @@ class ClassificationProject(object): plt.ylabel('loss') plt.xlabel('epoch') plt.legend(['train','test'], loc='upper left') + if log: + plt.yscale("log") plt.savefig(os.path.join(self.project_dir, "losses.pdf")) plt.clf() - def plot_accuracy(self, all_trainings=False): + def plot_accuracy(self, all_trainings=False, log=False): """ Plot the value of the accuracy metric for each epoch @@ -894,12 +897,15 @@ class ClassificationProject(object): hist_dict = self.csv_hist logger.info("Plot accuracy") + plt.plot(hist_dict['acc']) plt.plot(hist_dict['val_acc']) plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') plt.legend(['train', 'test'], loc='upper left') + if log: + plt.yscale("log") plt.savefig(os.path.join(self.project_dir, "accuracy.pdf")) plt.clf() -- GitLab