Skip to content
Snippets Groups Projects
Commit 3235c3ea authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

log scale optional for plots

parent da18e449
No related branches found
No related tags found
No related merge requests found
...@@ -822,7 +822,7 @@ class ClassificationProject(object): ...@@ -822,7 +822,7 @@ class ClassificationProject(object):
plt.clf() plt.clf()
def plot_score(self): def plot_score(self, log=True):
plot_opts = dict(bins=50, range=(0, 1)) plot_opts = dict(bins=50, range=(0, 1))
centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts) centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts) centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
...@@ -834,9 +834,10 @@ class ClassificationProject(object): ...@@ -834,9 +834,10 @@ class ClassificationProject(object):
ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train") ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test") ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test")
ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test") ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test")
ax.set_yscale("log") if log:
ax.set_yscale("log")
ax.set_xlabel("NN output") ax.set_xlabel("NN output")
plt.legend(loc='upper center', framealpha=0.5) fig.legend(loc='upper center', framealpha=0.5)
fig.savefig(os.path.join(self.project_dir, "scores.pdf")) fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
...@@ -851,7 +852,7 @@ class ClassificationProject(object): ...@@ -851,7 +852,7 @@ class ClassificationProject(object):
hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]] hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
return hist_dict return hist_dict
def plot_loss(self, all_trainings=False): def plot_loss(self, all_trainings=False, log=False):
""" """
Plot the value of the loss function for each epoch Plot the value of the loss function for each epoch
...@@ -873,11 +874,13 @@ class ClassificationProject(object): ...@@ -873,11 +874,13 @@ class ClassificationProject(object):
plt.ylabel('loss') plt.ylabel('loss')
plt.xlabel('epoch') plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left') plt.legend(['train','test'], loc='upper left')
if log:
plt.yscale("log")
plt.savefig(os.path.join(self.project_dir, "losses.pdf")) plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
plt.clf() plt.clf()
def plot_accuracy(self, all_trainings=False): def plot_accuracy(self, all_trainings=False, log=False):
""" """
Plot the value of the accuracy metric for each epoch Plot the value of the accuracy metric for each epoch
...@@ -894,12 +897,15 @@ class ClassificationProject(object): ...@@ -894,12 +897,15 @@ class ClassificationProject(object):
hist_dict = self.csv_hist hist_dict = self.csv_hist
logger.info("Plot accuracy") logger.info("Plot accuracy")
plt.plot(hist_dict['acc']) plt.plot(hist_dict['acc'])
plt.plot(hist_dict['val_acc']) plt.plot(hist_dict['val_acc'])
plt.title('model accuracy') plt.title('model accuracy')
plt.ylabel('accuracy') plt.ylabel('accuracy')
plt.xlabel('epoch') plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left') plt.legend(['train', 'test'], loc='upper left')
if log:
plt.yscale("log")
plt.savefig(os.path.join(self.project_dir, "accuracy.pdf")) plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
plt.clf() plt.clf()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment