diff --git a/toolkit.py b/toolkit.py index 7e1954a2bd89e6b1eece8f42508c7371de156202..9605aa4b887954c75c9ef6fa95ef178f663c014e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -878,6 +878,8 @@ class ClassificationProject(object): def plot_significance(self, lumifactor=1., significanceFunction=None): + logger.info("Plot significances") + plot_opts = dict(bins=50, range=(0, 1)) centers_sig_train, hist_sig_train, rel_errors_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), weights=self.w_train[self.y_train==1], **plot_opts) centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), weights=self.w_train[self.y_train==0], **plot_opts) @@ -911,8 +913,8 @@ class ClassificationProject(object): fig, ax = plt.subplots() width = centers_sig_train[1]-centers_sig_train[0] - ax.plot(centers_bkg_train, significances_train, label="train") - ax.plot(centers_bkg_test, significances_test, label="test") + ax.plot(centers_bkg_train, significances_train, label="train, Z_max={}".format(np.amax(significances_train))) + ax.plot(centers_bkg_test, significances_test, label="test, Z_max={}".format(np.amax(significances_test))) ax.set_xlabel("Cut on NN score") ax.set_ylabel("Significance") ax.legend(loc='lower center', framealpha=0.5) @@ -993,6 +995,7 @@ class ClassificationProject(object): self.plot_loss() self.plot_score() self.plot_weights() + self.plot_significance() def create_getter(dataset_name):