From f64532c3aa2c08d9c6723664f2433d94df644c70 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 20 Aug 2018 11:17:47 +0200 Subject: [PATCH] in addition to saving, display plots in interactive mode --- compare.py | 25 ++++++++++++++----------- plotting.py | 10 ++++++++++ toolkit.py | 17 +++++++++-------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/compare.py b/compare.py index 7e4c9a8..98fc4eb 100755 --- a/compare.py +++ b/compare.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt from sklearn.metrics import roc_curve, auc from .toolkit import ClassificationProject +from .plotting import save_show """ A few functions to compare different setups @@ -62,7 +63,7 @@ def overlay_ROC(filename, *projects, **kwargs): if plot_thresholds: # to fit right y-axis description fig.tight_layout() - fig.savefig(filename) + save_show(plt, fig, filename) plt.close(fig) def overlay_loss(filename, *projects, **kwargs): @@ -78,22 +79,24 @@ def overlay_loss(filename, *projects, **kwargs): prop_cycle = plt.rcParams['axes.prop_cycle'] colors = prop_cycle.by_key()['color'] + fig, ax = plt.subplots() + for p,color in zip(projects,colors): hist_dict = p.csv_hist - plt.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color) - plt.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color) + ax.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color) + ax.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color) - plt.ylabel('loss') - plt.xlabel('epoch') + ax.set_ylabel('loss') + ax.set_xlabel('epoch') if log: - plt.yscale("log") + ax.set_yscale("log") if xlim is not None: - plt.xlim(*xlim) + ax.set_xlim(*xlim) if ylim is not None: - plt.ylim(*ylim) - plt.legend(loc='upper right') - plt.savefig(filename) - plt.clf() + ax.set_ylim(*ylim) + ax.legend(loc='upper right') + save_show(plt, fig, filename) + plt.close(fig) diff --git a/plotting.py b/plotting.py index ae25dfc..c8a7425 100644 --- a/plotting.py +++ b/plotting.py @@ -20,6 +20,16 @@ logger.addHandler(logging.NullHandler()) Some further plotting functions """ +def save_show(plt, fig, filename): + "Save a figure and show it in case we are in ipython or jupyter notebook." + fig.savefig(filename) + try: + get_ipython + plt.show() + except NameError: + pass + + def get_mean_event(x, y, class_label): return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])] diff --git a/toolkit.py b/toolkit.py index d366a56..b02ae40 100755 --- a/toolkit.py +++ b/toolkit.py @@ -42,6 +42,7 @@ from keras import backend as K import matplotlib.pyplot as plt from .utils import WeightedRobustScaler, weighted_quantile, poisson_asimov_significance +from .plotting import save_show # configure number of cores # this doesn't seem to work, but at least with these settings keras only uses 4 processes @@ -901,10 +902,10 @@ class ClassificationProject(object): logger.info("Create/Update scores for train/test sample") if do_test: - self.scores_test = self.predict(self.x_test, mode=mode) + self.scores_test = self.predict(self.x_test, mode=mode).reshape(-1) self._dump_to_hdf5("scores_test") if do_train: - self.scores_train = self.predict(self.x_train, mode=mode) + self.scores_train = self.predict(self.x_train, mode=mode).reshape(-1) self._dump_to_hdf5("scores_train") @@ -1053,7 +1054,7 @@ class ClassificationProject(object): plot_dir = os.path.join(self.project_dir, "plots") if not os.path.exists(plot_dir): os.mkdir(plot_dir) - fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index))) + save_show(plt, fig, os.path.join(plot_dir, "var_{}.pdf".format(var_index))) plt.close(fig) @@ -1063,12 +1064,12 @@ class ClassificationProject(object): sig = self.w_train_tot[self.y_train == 1] ax.hist(bkg, bins=bins, range=range, color="b", alpha=0.5) ax.set_yscale("log") - fig.savefig(os.path.join(self.project_dir, "eventweights_bkg.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "eventweights_bkg.pdf")) plt.close(fig) fig, ax = plt.subplots() ax.hist(sig, bins=bins, range=range, color="r", alpha=0.5) ax.set_yscale("log") - fig.savefig(os.path.join(self.project_dir, "eventweights_sig.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "eventweights_sig.pdf")) plt.close(fig) @@ -1147,7 +1148,7 @@ class ClassificationProject(object): if apply_class_weight: ax.set_title("Class weights applied") ax.legend(loc='upper center', framealpha=0.5) - fig.savefig(os.path.join(self.project_dir, "scores.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "scores.pdf")) plt.close(fig) @@ -1203,7 +1204,7 @@ class ClassificationProject(object): ax.set_xlabel("Cut on NN score") ax.set_ylabel("Significance") ax.legend(loc='lower center', framealpha=0.5) - fig.savefig(os.path.join(self.project_dir, "significances_hist.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "significances_hist.pdf")) plt.close(fig) @@ -1276,7 +1277,7 @@ class ClassificationProject(object): ax.set_xlim(0, 1) ax2.set_ylabel("Threshold") ax.legend() - fig.savefig(os.path.join(self.project_dir, "significances.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "significances.pdf")) plt.close(fig) -- GitLab