Skip to content
Snippets Groups Projects
Commit 1fce12f7 authored by Nikolai's avatar Nikolai
Browse files

restructure plot_score function

parent cb83fad4
No related branches found
No related tags found
No related merge requests found
......@@ -1097,19 +1097,36 @@ class ClassificationProject(object):
plt.clf()
def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)), ylim=None, xlim=None):
centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts)
centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts)
errors_sig_test = hist_sig_test*rel_errors_sig_test
errors_bkg_test = hist_bkg_test*rel_errors_bkg_test
def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)),
ylim=None, xlim=None, density=True, lumifactor=None, apply_class_weight=True):
fig, ax = plt.subplots()
width = centers_sig_train[1]-centers_sig_train[0]
ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test")
ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test")
for scores, weights, y, class_label, fn, opts in [
(self.scores_train, self.w_train, self.y_train, 1, ax.bar, dict(color="r", label="signal train")),
(self.scores_train, self.w_train, self.y_train, 0, ax.bar, dict(color="b", label="background train")),
(self.scores_test, self.w_test, self.y_test, 1, ax.errorbar, dict(fmt="ro", label="signal test")),
(self.scores_test, self.w_test, self.y_test, 0, ax.errorbar, dict(fmt="bo", label="background test")),
]:
weights = weights[y==class_label]
if apply_class_weight is True and (lumifactor is not None):
logger.warning("not applying class weight, since lumifactor given")
if apply_class_weight and (lumifactor is None):
weights = weights*self.class_weight[class_label]
if lumifactor is not None:
weights = weights*lumifactor
centers, hist, rel_errors = self.get_bin_centered_hist(
scores[y==class_label].reshape(-1),
weights=weights,
**plot_opts
)
width = centers[1]-centers[0]
if density:
hist = hist/width
if fn == ax.errorbar:
errors = rel_errors*hist
opts.update(yerr=errors)
else:
opts.update(width=width, alpha=0.5)
fn(centers, hist, **opts)
if log:
ax.set_yscale("log")
if ylim is not None:
......@@ -1117,7 +1134,13 @@ class ClassificationProject(object):
if xlim is not None:
ax.set_xlim(*xlim)
ax.set_xlabel("NN output")
fig.legend(loc='upper center', framealpha=0.5)
if density:
ax.set_ylabel("dN / d(NN output)")
else:
ax.set_ylabel("Events / {:.2f}".format(width))
if apply_class_weight:
ax.set_title("Class weights applied")
ax.legend(loc='upper center', framealpha=0.5)
fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
plt.close(fig)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment