diff --git a/toolkit.py b/toolkit.py
index d124ef919d911dd95e9dd7bfadd28dab09e3b40d..fa1c8935593eca402b7db31607cbb0b632d52552 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -179,14 +179,18 @@ class ClassificationProject(object):
 
     :param normalize_weights: normalize the weights to mean 1
 
+    :param planing_vars: tuple (expr, bins, range) describing the variable in
+                         which a binwise reweighting is performed, such that its
+                         distribution becomes flat ("Data planing")
+
     """
 
 
     # Datasets that are stored to (and dynamically loaded from) hdf5
-    dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test"]
+    dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "scores_train", "scores_test", "planing_array"]
 
     # Datasets that are retrieved from ROOT trees the first time
-    dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]
+    dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test", "planing_array"]
 
     def __init__(self, name, *args, **kwargs):
         if len(args) < 1 and len(kwargs) < 1:
@@ -245,7 +249,9 @@ class ClassificationProject(object):
                         loss='binary_crossentropy',
                         mask_value=None,
                         apply_class_weight=True,
-                        normalize_weights=True):
+                        normalize_weights=True,
+                        planing_vars=(None, None, None),
+    ):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -321,6 +327,8 @@ class ClassificationProject(object):
         self.apply_class_weight = apply_class_weight
         self.normalize_weights = normalize_weights
 
+        self.planing_var, self.planing_bins, self.planing_range = planing_vars
+
         self.s_train = None
         self.b_train = None
         self.s_test = None
@@ -334,10 +342,14 @@ class ClassificationProject(object):
         self._w_test = None
         self._scores_train = None
         self._scores_test = None
+        self._planing_array = None
 
         # class weighted training data (divided by mean)
         self._w_train_tot = None
 
+        # planing weights in case requested
+        self._w_train_plane = None
+
         self._s_eventlist_train = None
         self._b_eventlist_train = None
 
@@ -400,20 +412,27 @@ class ClassificationProject(object):
                 signal_chain.AddFile(filename, -1, treename)
             for filename, treename in self.bkg_trees:
                 bkg_chain.AddFile(filename, -1, treename)
+
+            branches = self.branches
+            if self.planing_var is not None:
+                branches.append(self.planing_var)
+
+            print(branches)
+
             self.s_train = tree2array(signal_chain,
-                                      branches=self.branches+[self.weight_expr]+self.identifiers,
+                                      branches=branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
                                       start=0, step=self.step_signal, stop=self.stop_train)
             self.b_train = tree2array(bkg_chain,
-                                      branches=self.branches+[self.weight_expr]+self.identifiers,
+                                      branches=branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
                                       start=0, step=self.step_bkg, stop=self.stop_train)
             self.s_test = tree2array(signal_chain,
-                                     branches=self.branches+[self.weight_expr],
+                                     branches=branches+[self.weight_expr],
                                      selection=self.selection,
                                      start=1, step=self.step_signal, stop=self.stop_test)
             self.b_test = tree2array(bkg_chain,
-                                     branches=self.branches+[self.weight_expr],
+                                     branches=branches+[self.weight_expr],
                                      selection=self.selection,
                                      start=1, step=self.step_bkg, stop=self.stop_test)
 
@@ -426,6 +445,10 @@ class ClassificationProject(object):
             self.b_eventlist_train = self.b_train[self.identifiers].astype(dtype=[(branchName, "u8") for branchName in self.identifiers])
             self._dump_training_list()
 
+            # store planing branch
+            if self.planing_var is not None:
+                self.planing_array = np.concatenate([self.s_train[self.planing_var], self.b_train[self.planing_var]])
+
             # now we don't need the identifiers anymore
             self.s_train = self.s_train[self.fields+[self.weight_expr]]
             self.b_train = self.b_train[self.fields+[self.weight_expr]]
@@ -821,6 +844,49 @@ class ClassificationProject(object):
         self.data_shuffled = True
 
 
+    @property
+    def w_train_plane(self):
+        """
+        Per-event training weights that flatten ("plane") the distribution of
+        the planing variable for both signal and background (class labels 0
+        and 1 are treated separately). Each event gets 1/(weighted content of
+        its bin); events in empty bins or outside the binning range get 0.
+        The result is cached. Returns None if there is no planing array.
+        """
+        if self._w_train_plane is None and self.planing_array is not None:
+            w_plane = np.zeros(len(self.x_train), dtype=float)
+            w_tot = self.get_total_weight()  # same for both classes, compute once
+            for class_label in [0, 1]:
+                mask = self.y_train == class_label
+                ar = self.planing_array[mask]
+                hist, edges = np.histogram(ar, bins=self.planing_bins,
+                                           range=self.planing_range, weights=w_tot[mask])
+                # 1/hist with empty bins -> 0; trailing 0 serves underflow (-1) and overflow
+                sfs = np.append(np.divide(1., hist, out=np.zeros(len(hist)), where=hist != 0), 0.)
+                bin_inds = np.digitize(ar, edges) - 1  # digitize counts bins from 1
+                bin_inds[ar == edges[-1]] = len(hist) - 1  # histogram's last bin is closed
+                w_plane[mask] = sfs[bin_inds]
+            self._w_train_plane = w_plane
+        return self._w_train_plane
+
+
+    def get_total_weight(self):
+        """Return the training sample weights, times the per-event class weight
+        if `apply_class_weight` is set, divided by their mean if
+        `normalize_weights` is set. Raises ValueError if no data is loaded."""
+        per_class = self.balanced_class_weight if self.balance_dataset else self.class_weight
+        if not self.data_loaded:
+            raise ValueError("Data not loaded! can't calculate total weight")
+        if not self.apply_class_weight:
+            total = np.array(self.w_train)
+        else:
+            # pick the class weight of each event via its integer class label
+            total = self.w_train*np.array(per_class)[self.y_train.astype(int)]
+        if self.normalize_weights:
+            total /= np.mean(total)
+        return total
+
+
     @property
     def w_train_tot(self):
         "(sample weight * class weight), divided by mean"
@@ -831,12 +897,9 @@ class ClassificationProject(object):
         if not self.data_loaded:
             raise ValueError("Data not loaded! can't calculate total weight")
         if self._w_train_tot is None:
-            if self.apply_class_weight:
-                self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
-            else:
-                self._w_train_tot = np.array(self.w_train)
-            if self.normalize_weights:
-                self._w_train_tot /= np.mean(self._w_train_tot)
+            self._w_train_tot = self.get_total_weight()
+            if self.w_train_plane is not None:
+                self._w_train_tot *= self._w_train_plane
         return self._w_train_tot