diff --git a/toolkit.py b/toolkit.py
index b8ae99305951737f50d3d0712a333bb1449a7f19..df89ef3b0d36fbd4e4aab7ca34dcaca6322e370d 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -921,15 +921,18 @@ class ClassificationProject(object):
 
         significances_train = []
         significances_test = []
-        for hist_sig, hist_bkg, rel_errors_sig, rel_errors_bkg, significances in [
-                (hist_sig_train, hist_bkg_train, rel_errors_bkg_train, rel_errors_sig_train, significances_train),
-                (hist_sig_test, hist_bkg_test, rel_errors_bkg_test, rel_errors_sig_test, significances_test),
+        for hist_sig, hist_bkg, rel_errors_sig, rel_errors_bkg, significances, w, y in [
+                (hist_sig_train, hist_bkg_train, rel_errors_bkg_train, rel_errors_sig_train, significances_train, self.w_train, self.y_train),
+                (hist_sig_test, hist_bkg_test, rel_errors_bkg_test, rel_errors_sig_test, significances_test, self.w_test, self.y_test),
         ]:
+            # factor to rescale due to using only a fraction of events (training and test samples)
+            normfactor_sig = (np.sum(self.w_train[self.y_train==1])+np.sum(self.w_test[self.y_test==1]))/np.sum(w[y==1])
+            normfactor_bkg = (np.sum(self.w_train[self.y_train==0])+np.sum(self.w_test[self.y_test==0]))/np.sum(w[y==0])
             # first set nan values to 0 and multiply by lumi
             for arr in hist_sig, hist_bkg, rel_errors_bkg:
                 arr[np.isnan(arr)] = 0
-            hist_sig *= lumifactor
-            hist_bkg *= lumifactor
+            hist_sig *= lumifactor*normfactor_sig
+            hist_bkg *= lumifactor*normfactor_bkg
             for i in range(len(hist_sig)):
                 s = sum(hist_sig[i:])
                 b = sum(hist_bkg[i:])
@@ -941,6 +944,8 @@ class ClassificationProject(object):
                         z = 0
                 else:
                     z = significanceFunction(s, b, db)
+                if z == float('inf'):
+                    z = 0
                 logger.debug("s, b, db, z = {}, {}, {}, {}".format(s, b, db, z))
                 significances.append(z)