Skip to content
Snippets Groups Projects
Commit 3432c091 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

ylim option for loss plot

parent 9db4ffce
No related branches found
No related tags found
No related merge requests found
......@@ -845,7 +845,7 @@ class ClassificationProject(object):
centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
except ValueError:
# weird, probably not always working workaround for a numpy bug
plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1])))
logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
......@@ -980,7 +980,7 @@ class ClassificationProject(object):
hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
return hist_dict
def plot_loss(self, all_trainings=False, log=False):
def plot_loss(self, all_trainings=False, log=False, ylim=None):
"""
Plot the value of the loss function for each epoch
......@@ -1004,6 +1004,8 @@ class ClassificationProject(object):
plt.legend(['train','test'], loc='upper left')
if log:
plt.yscale("log")
if ylim is not None:
plt.ylim(*ylim)
plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
plt.clf()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment