From 99f3359d73d4ae2449f5e1dd9ac6f25b00943d16 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 23 Jul 2018 18:01:14 +0200 Subject: [PATCH] more options for overlay functions - multiple projects in browsing script --- browse.py | 7 +++++++ compare.py | 21 +++++++++++++++++---- toolkit.py | 4 +++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/browse.py b/browse.py index b641459..643624e 100755 --- a/browse.py +++ b/browse.py @@ -10,3 +10,10 @@ logging.basicConfig() logging.getLogger("KerasROOTClassification").setLevel(logging.INFO) c = ClassificationProject(sys.argv[1]) + +cs = [] +cs.append(c) + +if len(sys.argv) > 2: + for project_name in sys.argv[2:]: + cs.append(ClassificationProject(project_name)) diff --git a/compare.py b/compare.py index e6496db..5dfb2cf 100755 --- a/compare.py +++ b/compare.py @@ -37,7 +37,13 @@ def overlay_ROC(filename, *projects): plt.savefig(filename) plt.clf() -def overlay_loss(filename, *projects): +def overlay_loss(filename, *projects, **kwargs): + + xlim = kwargs.pop("xlim", None) + ylim = kwargs.pop("ylim", None) + log = kwargs.pop("log", False) + if kwargs: + raise KeyError("Unknown kwargs: {}".format(kwargs)) logger.info("Overlay loss curves for {}".format([p.name for p in projects])) @@ -45,11 +51,18 @@ def overlay_loss(filename, *projects): colors = prop_cycle.by_key()['color'] for p,color in zip(projects,colors): - plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name, color=color) - plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name, color=color) - + hist_dict = p.csv_hist + plt.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color) + plt.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color) + plt.ylabel('loss') plt.xlabel('epoch') + if log: + plt.yscale("log") + if xlim is not None: + plt.xlim(*xlim) + if ylim is not None: + plt.ylim(*ylim) plt.legend(loc='upper right') plt.savefig(filename) plt.clf() diff --git a/toolkit.py b/toolkit.py index 26bdd18..ee4288c 100755 --- a/toolkit.py +++ b/toolkit.py @@ -980,7 +980,7 @@ class ClassificationProject(object): hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]] return hist_dict - def plot_loss(self, all_trainings=False, log=False, ylim=None): + def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None): """ Plot the value of the loss function for each epoch @@ -1004,6 +1004,8 @@ class ClassificationProject(object): plt.legend(['train','test'], loc='upper left') if log: plt.yscale("log") + if xlim is not None: + plt.xlim(*xlim) if ylim is not None: plt.ylim(*ylim) plt.savefig(os.path.join(self.project_dir, "losses.pdf")) -- GitLab