Skip to content
Snippets Groups Projects
Commit 3432c091 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

ylim option for loss plot

parent 9db4ffce
No related branches found
No related tags found
No related merge requests found
...@@ -845,7 +845,7 @@ class ClassificationProject(object): ...@@ -845,7 +845,7 @@ class ClassificationProject(object):
centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
except ValueError: except ValueError:
# weird, probably not always working workaround for a numpy bug # weird, probably not always working workaround for a numpy bug
plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1]))) plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1])))
logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range)) logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
...@@ -980,7 +980,7 @@ class ClassificationProject(object): ...@@ -980,7 +980,7 @@ class ClassificationProject(object):
hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]] hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
return hist_dict return hist_dict
def plot_loss(self, all_trainings=False, log=False): def plot_loss(self, all_trainings=False, log=False, ylim=None):
""" """
Plot the value of the loss function for each epoch Plot the value of the loss function for each epoch
...@@ -1004,6 +1004,8 @@ class ClassificationProject(object): ...@@ -1004,6 +1004,8 @@ class ClassificationProject(object):
plt.legend(['train','test'], loc='upper left') plt.legend(['train','test'], loc='upper left')
if log: if log:
plt.yscale("log") plt.yscale("log")
if ylim is not None:
plt.ylim(*ylim)
plt.savefig(os.path.join(self.project_dir, "losses.pdf")) plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
plt.clf() plt.clf()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment