diff --git a/compare.py b/compare.py index 7e4c9a86870dd1b01b4e733d0b9164929b049758..98fc4eb39d0fb6d0b74319c6c1794a26abda828d 100755 --- a/compare.py +++ b/compare.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt from sklearn.metrics import roc_curve, auc from .toolkit import ClassificationProject +from .plotting import save_show """ A few functions to compare different setups @@ -62,7 +63,7 @@ def overlay_ROC(filename, *projects, **kwargs): if plot_thresholds: # to fit right y-axis description fig.tight_layout() - fig.savefig(filename) + save_show(plt, fig, filename) plt.close(fig) def overlay_loss(filename, *projects, **kwargs): @@ -78,22 +79,24 @@ def overlay_loss(filename, *projects, **kwargs): prop_cycle = plt.rcParams['axes.prop_cycle'] colors = prop_cycle.by_key()['color'] + fig, ax = plt.subplots() + for p,color in zip(projects,colors): hist_dict = p.csv_hist - plt.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color) - plt.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color) + ax.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color) + ax.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color) - plt.ylabel('loss') - plt.xlabel('epoch') + ax.set_ylabel('loss') + ax.set_xlabel('epoch') if log: - plt.yscale("log") + ax.set_yscale("log") if xlim is not None: - plt.xlim(*xlim) + ax.set_xlim(*xlim) if ylim is not None: - plt.ylim(*ylim) - plt.legend(loc='upper right') - plt.savefig(filename) - plt.clf() + ax.set_ylim(*ylim) + ax.legend(loc='upper right') + save_show(plt, fig, filename) + plt.close(fig) diff --git a/plotting.py b/plotting.py index ae25dfc096feffa92f3470e134155b7439638f6c..c8a7425921d1daae042ce5d645ceb536dfb46074 100644 --- a/plotting.py +++ b/plotting.py @@ -20,6 +20,16 @@ logger.addHandler(logging.NullHandler()) Some further plotting functions """ +def save_show(plt, fig, filename): + "Save a figure and show it in case we are in ipython or jupyter notebook." + fig.savefig(filename) + try: + get_ipython + plt.show() + except NameError: + pass + + def get_mean_event(x, y, class_label): return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])] diff --git a/toolkit.py b/toolkit.py index d366a560728d9ec1f5f95f93cd88f13f62a6012f..b02ae40060513a5a7a34e8d3c99e49e932102b06 100755 --- a/toolkit.py +++ b/toolkit.py @@ -42,6 +42,7 @@ from keras import backend as K import matplotlib.pyplot as plt from .utils import WeightedRobustScaler, weighted_quantile, poisson_asimov_significance +from .plotting import save_show # configure number of cores # this doesn't seem to work, but at least with these settings keras only uses 4 processes @@ -901,10 +902,10 @@ class ClassificationProject(object): logger.info("Create/Update scores for train/test sample") if do_test: - self.scores_test = self.predict(self.x_test, mode=mode) + self.scores_test = self.predict(self.x_test, mode=mode).reshape(-1) self._dump_to_hdf5("scores_test") if do_train: - self.scores_train = self.predict(self.x_train, mode=mode) + self.scores_train = self.predict(self.x_train, mode=mode).reshape(-1) self._dump_to_hdf5("scores_train") @@ -1053,7 +1054,7 @@ class ClassificationProject(object): plot_dir = os.path.join(self.project_dir, "plots") if not os.path.exists(plot_dir): os.mkdir(plot_dir) - fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index))) + save_show(plt, fig, os.path.join(plot_dir, "var_{}.pdf".format(var_index))) plt.close(fig) @@ -1063,12 +1064,12 @@ class ClassificationProject(object): sig = self.w_train_tot[self.y_train == 1] ax.hist(bkg, bins=bins, range=range, color="b", alpha=0.5) ax.set_yscale("log") - fig.savefig(os.path.join(self.project_dir, "eventweights_bkg.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "eventweights_bkg.pdf")) plt.close(fig) fig, ax = plt.subplots() ax.hist(sig, bins=bins, range=range, color="r", alpha=0.5) ax.set_yscale("log") - fig.savefig(os.path.join(self.project_dir, "eventweights_sig.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "eventweights_sig.pdf")) plt.close(fig) @@ -1147,7 +1148,7 @@ class ClassificationProject(object): if apply_class_weight: ax.set_title("Class weights applied") ax.legend(loc='upper center', framealpha=0.5) - fig.savefig(os.path.join(self.project_dir, "scores.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "scores.pdf")) plt.close(fig) @@ -1203,7 +1204,7 @@ class ClassificationProject(object): ax.set_xlabel("Cut on NN score") ax.set_ylabel("Significance") ax.legend(loc='lower center', framealpha=0.5) - fig.savefig(os.path.join(self.project_dir, "significances_hist.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "significances_hist.pdf")) plt.close(fig) @@ -1276,7 +1277,7 @@ class ClassificationProject(object): ax.set_xlim(0, 1) ax2.set_ylabel("Threshold") ax.legend() - fig.savefig(os.path.join(self.project_dir, "significances.pdf")) + save_show(plt, fig, os.path.join(self.project_dir, "significances.pdf")) plt.close(fig)