diff --git a/browse.py b/browse.py index b641459fb75196a2ed350fff40570c4cf1b422d4..643624ec9d9825103d68a096950878152b6b4d13 100755 --- a/browse.py +++ b/browse.py @@ -10,3 +10,10 @@ logging.basicConfig() logging.getLogger("KerasROOTClassification").setLevel(logging.INFO) c = ClassificationProject(sys.argv[1]) + +cs = [] +cs.append(c) + +if len(sys.argv) > 2: + for project_name in sys.argv[2:]: + cs.append(ClassificationProject(project_name)) diff --git a/compare.py b/compare.py index e6496db3fc2e72170d2979729572e160cf7e0c3c..5dfb2cf2df3633ac5ca8e4182b5a8b179ceb4585 100755 --- a/compare.py +++ b/compare.py @@ -37,7 +37,13 @@ def overlay_ROC(filename, *projects): plt.savefig(filename) plt.clf() -def overlay_loss(filename, *projects): +def overlay_loss(filename, *projects, **kwargs): + + xlim = kwargs.pop("xlim", None) + ylim = kwargs.pop("ylim", None) + log = kwargs.pop("log", False) + if kwargs: + raise KeyError("Unknown kwargs: {}".format(kwargs)) logger.info("Overlay loss curves for {}".format([p.name for p in projects])) @@ -45,11 +51,18 @@ def overlay_loss(filename, *projects): colors = prop_cycle.by_key()['color'] for p,color in zip(projects,colors): - plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name, color=color) - plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name, color=color) - + hist_dict = p.csv_hist + plt.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color) + plt.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color) + plt.ylabel('loss') plt.xlabel('epoch') + if log: + plt.yscale("log") + if xlim is not None: + plt.xlim(*xlim) + if ylim is not None: + plt.ylim(*ylim) plt.legend(loc='upper right') plt.savefig(filename) plt.clf() diff --git a/toolkit.py b/toolkit.py index 26bdd18dc6ead8e01af7991fe8c36c94b644981c..ee4288ccc0017a172649c98edae800f5ec86ce6a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -980,7 +980,7 @@ class ClassificationProject(object): hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]] return hist_dict - def plot_loss(self, all_trainings=False, log=False, ylim=None): + def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None): """ Plot the value of the loss function for each epoch @@ -1004,6 +1004,8 @@ class ClassificationProject(object): plt.legend(['train','test'], loc='upper left') if log: plt.yscale("log") + if xlim is not None: + plt.xlim(*xlim) if ylim is not None: plt.ylim(*ylim) plt.savefig(os.path.join(self.project_dir, "losses.pdf"))