Skip to content
Snippets Groups Projects
Commit 7d1f43de authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Error bars for test sample in scores plot

parent 54d5cf3a
No related branches found
No related tags found
No related merge requests found
...@@ -14,6 +14,7 @@ import yaml ...@@ -14,6 +14,7 @@ import yaml
import pickle import pickle
import importlib import importlib
import csv import csv
import math
import logging import logging
logger = logging.getLogger("KerasROOTClassification") logger = logging.getLogger("KerasROOTClassification")
...@@ -730,9 +731,19 @@ class ClassificationProject(object): ...@@ -730,9 +731,19 @@ class ClassificationProject(object):
def get_bin_centered_hist(x, scale_factor=None, **np_kwargs): def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
hist, bins = np.histogram(x, **np_kwargs) hist, bins = np.histogram(x, **np_kwargs)
centers = (bins[:-1] + bins[1:]) / 2 centers = (bins[:-1] + bins[1:]) / 2
if "weights" in np_kwargs:
errors = []
for left, right in zip(bins, bins[1:]):
indices = np.where((x >= left) & (x <= right))[0]
sumw2 = np.sum(np_kwargs["weights"][indices]**2)
content = np.sum(np_kwargs["weights"][indices])
errors.append(math.sqrt(sumw2)/content)
errors = np.array(errors)
else:
errors = np.sqrt(hist)/hist
if scale_factor is not None: if scale_factor is not None:
hist *= scale_factor hist *= scale_factor
return centers, hist return centers, hist, errors
def plot_input(self, var_index): def plot_input(self, var_index):
...@@ -756,14 +767,14 @@ class ClassificationProject(object): ...@@ -756,14 +767,14 @@ class ClassificationProject(object):
logger.debug("Calculated range based on percentiles: {}".format(plot_range)) logger.debug("Calculated range based on percentiles: {}".format(plot_range))
try: try:
centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
except ValueError: except ValueError:
# weird, probably not always working workaround for a numpy bug # weird, probably not always working workaround for a numpy bug
plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1]))) plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range)) logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
width = centers_sig[1]-centers_sig[0] width = centers_sig[1]-centers_sig[0]
ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width) ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
...@@ -813,19 +824,19 @@ class ClassificationProject(object): ...@@ -813,19 +824,19 @@ class ClassificationProject(object):
def plot_score(self): def plot_score(self):
plot_opts = dict(bins=50, range=(0, 1)) plot_opts = dict(bins=50, range=(0, 1))
centers_sig_train, hist_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts) centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
centers_bkg_train, hist_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts) centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
centers_sig_test, hist_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts) centers_sig_test, hist_sig_test, errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts)
centers_bkg_test, hist_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts) centers_bkg_test, hist_bkg_test, errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts)
fig, ax = plt.subplots() fig, ax = plt.subplots()
width = centers_sig_train[1]-centers_sig_train[0] width = centers_sig_train[1]-centers_sig_train[0]
ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train") ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train") ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
ax.scatter(centers_bkg_test, hist_bkg_test, color="b", label="background test") ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test")
ax.scatter(centers_sig_test, hist_sig_test, color="r", label="signal test") ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test")
ax.set_yscale("log") ax.set_yscale("log")
ax.set_xlabel("NN output") ax.set_xlabel("NN output")
plt.legend(loc='upper right', framealpha=1.0) plt.legend(loc='upper center', framealpha=0.5)
fig.savefig(os.path.join(self.project_dir, "scores.pdf")) fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment