diff --git a/toolkit.py b/toolkit.py index 7f6cad06fcbc3af2f3504e5da1913511f2d90475..75cdbc1ec01c3a26c30ec3bc2bdcfda07477bb0a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -873,8 +873,7 @@ class ClassificationProject(object): plt.clf() - def plot_score(self, log=True): - plot_opts = dict(bins=50, range=(0, 1)) + def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1))): centers_sig_train, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts) centers_bkg_train, hist_bkg_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts) centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts) @@ -894,10 +893,9 @@ class ClassificationProject(object): fig.savefig(os.path.join(self.project_dir, "scores.pdf")) - def plot_significance(self, lumifactor=1., significanceFunction=None): + def plot_significance(self, lumifactor=1., significanceFunction=None, plot_opts=dict(bins=50, range=(0, 1))): logger.info("Plot significances") - plot_opts = dict(bins=50, range=(0, 1)) centers_sig_train, hist_sig_train, rel_errors_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), weights=self.w_train[self.y_train==1], **plot_opts) centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), weights=self.w_train[self.y_train==0], **plot_opts) centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), weights=self.w_test[self.y_test==1], **plot_opts)