diff --git a/toolkit.py b/toolkit.py index d84d0c1fe7a8d2b16ddb4e710be246ec387562ca..0bc2c2918098b2857c657110ba73893501942f0e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -877,6 +877,47 @@ class ClassificationProject(object): fig.savefig(os.path.join(self.project_dir, "scores.pdf")) + def plot_significance(self, lumifactor=1., significanceFunction=None): + plot_opts = dict(bins=50, range=(0, 1)) + centers_sig_train, hist_sig_train, rel_errors_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), weights=self.w_train[self.y_train==1], **plot_opts) + centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), weights=self.w_train[self.y_train==0], **plot_opts) + centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), weights=self.w_test[self.y_test==1], **plot_opts) + centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), weights=self.w_test[self.y_test==0], **plot_opts) + + significances_train = [] + significances_test = [] + for hist_sig, hist_bkg, rel_errors_sig, rel_errors_bkg, significances in [ + (hist_sig_train, hist_bkg_train, rel_errors_bkg_train, rel_errors_sig_train, significances_train), + (hist_sig_test, hist_bkg_test, rel_errors_bkg_test, rel_errors_sig_test, significances_test), + ]: + # first set nan values to 0 and multiply by lumi + for arr in hist_sig, hist_bkg, rel_errors_bkg: + arr[np.isnan(arr)] = 0 + hist_sig *= lumifactor + hist_bkg *= lumifactor + for i in range(len(hist_sig)): + s = sum(hist_sig[i:]) + b = sum(hist_bkg[i:]) + db = math.sqrt(sum((rel_errors_bkg[i:]*hist_bkg[i:])**2)) + if significanceFunction is None: + try: + z = s/math.sqrt(b+db**2) + except (ZeroDivisionError, ValueError) as e: + z = 0 + else: + z = significanceFunction(s, b, db) + logger.debug("s, b, db, z = {}, {}, {}, {}".format(s, b, db, z)) + significances.append(z) + + fig, ax = plt.subplots() + width = centers_sig_train[1]-centers_sig_train[0] + ax.plot(centers_bkg_train, significances_train, label="train") + ax.plot(centers_bkg_test, significances_test, label="test") + ax.set_xlabel("Cut on NN score") + ax.set_ylabel("Significance") + ax.legend(loc='lower center', framealpha=0.5) + fig.savefig(os.path.join(self.project_dir, "significances.pdf")) + @property def csv_hist(self):