diff --git a/toolkit.py b/toolkit.py
index d84d0c1fe7a8d2b16ddb4e710be246ec387562ca..0bc2c2918098b2857c657110ba73893501942f0e 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -877,6 +877,47 @@ class ClassificationProject(object):
         fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
 
 
+    def plot_significance(self, lumifactor=1., significanceFunction=None):
+        plot_opts = dict(bins=50, range=(0, 1))
+        centers_sig_train, hist_sig_train, rel_errors_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), weights=self.w_train[self.y_train==1], **plot_opts)
+        centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), weights=self.w_train[self.y_train==0], **plot_opts)
+        centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), weights=self.w_test[self.y_test==1], **plot_opts)
+        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), weights=self.w_test[self.y_test==0], **plot_opts)
+
+        significances_train = []
+        significances_test = []
+        for hist_sig, hist_bkg, rel_errors_sig, rel_errors_bkg, significances in [
+                (hist_sig_train, hist_bkg_train, rel_errors_bkg_train, rel_errors_sig_train, significances_train),
+                (hist_sig_test, hist_bkg_test, rel_errors_bkg_test, rel_errors_sig_test, significances_test),
+        ]:
+            # first set nan values to 0 and multiply by lumi
+            for arr in hist_sig, hist_bkg, rel_errors_bkg:
+                arr[np.isnan(arr)] = 0
+            hist_sig *= lumifactor
+            hist_bkg *= lumifactor
+            for i in range(len(hist_sig)):
+                s = sum(hist_sig[i:])
+                b = sum(hist_bkg[i:])
+                db = math.sqrt(sum((rel_errors_bkg[i:]*hist_bkg[i:])**2))
+                if significanceFunction is None:
+                    try:
+                        z = s/math.sqrt(b+db**2)
+                    except (ZeroDivisionError, ValueError) as e:
+                        z = 0
+                else:
+                    z = significanceFunction(s, b, db)
+                logger.debug("s, b, db, z = {}, {}, {}, {}".format(s, b, db, z))
+                significances.append(z)
+
+        fig, ax = plt.subplots()
+        width = centers_sig_train[1]-centers_sig_train[0]
+        ax.plot(centers_bkg_train, significances_train, label="train")
+        ax.plot(centers_bkg_test, significances_test, label="test")
+        ax.set_xlabel("Cut on NN score")
+        ax.set_ylabel("Significance")
+        ax.legend(loc='lower center', framealpha=0.5)
+        fig.savefig(os.path.join(self.project_dir, "significances.pdf"))
+
 
     @property
     def csv_hist(self):