Skip to content
Snippets Groups Projects
Commit 99f3359d authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

more options for overlay functions - multiple projects in browsing script

parent 3432c091
No related branches found
No related tags found
No related merge requests found
......@@ -10,3 +10,10 @@ logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
c = ClassificationProject(sys.argv[1])
cs = []
cs.append(c)
if len(sys.argv) > 2:
for project_name in sys.argv[2:]:
cs.append(ClassificationProject(project_name))
......@@ -37,7 +37,13 @@ def overlay_ROC(filename, *projects):
plt.savefig(filename)
plt.clf()
def overlay_loss(filename, *projects):
def overlay_loss(filename, *projects, **kwargs):
xlim = kwargs.pop("xlim", None)
ylim = kwargs.pop("ylim", None)
log = kwargs.pop("log", False)
if kwargs:
raise KeyError("Unknown kwargs: {}".format(kwargs))
logger.info("Overlay loss curves for {}".format([p.name for p in projects]))
......@@ -45,11 +51,18 @@ def overlay_loss(filename, *projects):
colors = prop_cycle.by_key()['color']
for p,color in zip(projects,colors):
plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name, color=color)
plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name, color=color)
hist_dict = p.csv_hist
plt.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color)
plt.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color)
plt.ylabel('loss')
plt.xlabel('epoch')
if log:
plt.yscale("log")
if xlim is not None:
plt.xlim(*xlim)
if ylim is not None:
plt.ylim(*ylim)
plt.legend(loc='upper right')
plt.savefig(filename)
plt.clf()
......
......@@ -980,7 +980,7 @@ class ClassificationProject(object):
hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
return hist_dict
def plot_loss(self, all_trainings=False, log=False, ylim=None):
def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None):
"""
Plot the value of the loss function for each epoch
......@@ -1004,6 +1004,8 @@ class ClassificationProject(object):
plt.legend(['train','test'], loc='upper left')
if log:
plt.yscale("log")
if xlim is not None:
plt.xlim(*xlim)
if ylim is not None:
plt.ylim(*ylim)
plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment