From 1fce12f7396ee135bcd11b83fac5bf217db678e0 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Mon, 20 Aug 2018 10:17:05 +0200 Subject: [PATCH] restructure plot_score function --- toolkit.py | 49 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/toolkit.py b/toolkit.py index 52b911f..3896daf 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1097,19 +1097,36 @@ class ClassificationProject(object): plt.clf() - def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)), ylim=None, xlim=None): - centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts) - centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts) - centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts) - centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts) - errors_sig_test = hist_sig_test*rel_errors_sig_test - errors_bkg_test = hist_bkg_test*rel_errors_bkg_test + def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)), + ylim=None, xlim=None, density=True, lumifactor=None, apply_class_weight=True): fig, ax = plt.subplots() - width = centers_sig_train[1]-centers_sig_train[0] - ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train") - ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train") - ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test") - ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test") + for scores, weights, y, class_label, fn, opts in [ + (self.scores_train, self.w_train, self.y_train, 1, ax.bar, dict(color="r", label="signal train")), + (self.scores_train, self.w_train, self.y_train, 0, ax.bar, dict(color="b", label="background train")), + (self.scores_test, self.w_test, self.y_test, 1, ax.errorbar, dict(fmt="ro", label="signal test")), + (self.scores_test, self.w_test, self.y_test, 0, ax.errorbar, dict(fmt="bo", label="background test")), + ]: + weights = weights[y==class_label] + if apply_class_weight is True and (lumifactor is not None): + logger.warning("not applying class weight, since lumifactor given") + if apply_class_weight and (lumifactor is None): + weights = weights*self.class_weight[class_label] + if lumifactor is not None: + weights = weights*lumifactor + centers, hist, rel_errors = self.get_bin_centered_hist( + scores[y==class_label].reshape(-1), + weights=weights, + **plot_opts + ) + width = centers[1]-centers[0] + if density: + hist = hist/width + if fn == ax.errorbar: + errors = rel_errors*hist + opts.update(yerr=errors) + else: + opts.update(width=width, alpha=0.5) + fn(centers, hist, **opts) if log: ax.set_yscale("log") if ylim is not None: @@ -1117,7 +1134,13 @@ class ClassificationProject(object): if xlim is not None: ax.set_xlim(*xlim) ax.set_xlabel("NN output") - fig.legend(loc='upper center', framealpha=0.5) + if density: + ax.set_ylabel("dN / d(NN output)") + else: + ax.set_ylabel("Events / {:.2f}".format(width)) + if apply_class_weight: + ax.set_title("Class weights applied") + ax.legend(loc='upper center', framealpha=0.5) fig.savefig(os.path.join(self.project_dir, "scores.pdf")) plt.close(fig) -- GitLab