diff --git a/toolkit.py b/toolkit.py
index 3e7b5743ab9e0f5aa41ea7312599735ac35d83a7..e25b458e90c62f4e278f31f7a3793565220487e3 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -330,8 +330,8 @@ class ClassificationProject(object):
         self._scores_train = None
         self._scores_test = None
 
-        # class weighted validation data
-        self._w_validation = None
+        # total training weight (sample weight * class weight)
+        self._w_train_tot = None
 
         self._s_eventlist_train = None
         self._b_eventlist_train = None
@@ -354,8 +354,6 @@ class ClassificationProject(object):
 
         self._fields = None
 
-        self._mean_train_weight = None
-
 
     @property
     def fields(self):
@@ -549,7 +547,7 @@ class ClassificationProject(object):
                     self._scaler = RobustScaler()
                 elif self.scaler_type == "WeightedRobustScaler":
                     self._scaler = WeightedRobustScaler()
-                    scaler_fit_kwargs["weights"] = self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]
+                    scaler_fit_kwargs["weights"] = self.w_train_tot
                 else:
                     raise ValueError("Scaler type {} unknown".format(self.scaler_type))
                 logger.info("Fitting {} to training data".format(self.scaler_type))
@@ -760,6 +758,8 @@ class ClassificationProject(object):
         np.random.shuffle(self.y_train)
         np.random.set_state(rn_state)
         np.random.shuffle(self.w_train)
+        np.random.set_state(rn_state)
+        np.random.shuffle(self.w_train_tot)
         if self._scores_train is not None:
             logger.info("Shuffling scores, since they are also there")
             np.random.set_state(rn_state)
@@ -767,35 +767,30 @@ class ClassificationProject(object):
 
 
     @property
-    def mean_train_weight(self):
-        if self._mean_train_weight is None:
-            self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)])
-        return self._mean_train_weight
-
-
-    @property
-    def w_validation(self):
-        "class weighted validation data weights"
-        split_index = int((1-self.validation_split)*len(self.x_train))
-        if self._w_validation is None:
-            self._w_validation = np.array(self.w_train[split_index:])
-            self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
-            self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
-        return self._w_validation/self.mean_train_weight
+    def w_train_tot(self):
+        "(sample weight * class weight), divided by mean"
+        if not self.balance_dataset:
+            class_weight = self.class_weight
+        else:
+            class_weight = self.balanced_class_weight
+        if self._w_train_tot is None:
+            self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)]
+            #self._w_train_tot /= np.mean(self._w_train_tot)
+        return self._w_train_tot
 
 
     @property
-    def class_weighted_validation_data(self):
-        "class weighted validation data. Attention: Shuffle training data before using this!"
+    def validation_data(self):
+        "Validation data. Attention: Shuffle training data before using this!"
         split_index = int((1-self.validation_split)*len(self.x_train))
-        return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
+        return self.x_train[split_index:], self.y_train[split_index:], self.w_train_tot[split_index:]
 
 
     @property
     def training_data(self):
-        "training data with validation data split off. Attention: Shuffle training data before using this!"
+        "Training data with validation data split off. Attention: Shuffle training data before using this!"
         split_index = int((1-self.validation_split)*len(self.x_train))
-        return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
+        return self.x_train[:split_index], self.y_train[:split_index], self.w_train_tot[:split_index]
 
 
     def yield_single_class_batch(self, class_label):
@@ -812,7 +807,7 @@ class ClassificationProject(object):
             for start in range(0, len(shuffled_idx), int(self.batch_size/2)):
                 yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]],
                        y_train[shuffled_idx[start:start+int(self.batch_size/2)]],
-                       w_train[shuffled_idx[start:start+int(self.batch_size/2)]]*self.balanced_class_weight[class_label])
+                       w_train[shuffled_idx[start:start+int(self.batch_size/2)]])
 
 
     def yield_balanced_batch(self):
@@ -834,6 +829,8 @@ class ClassificationProject(object):
 
         self.load()
 
+        self.shuffle_training_data()
+
         for branch_index, branch in enumerate(self.fields):
             self.plot_input(branch_index)
 
@@ -842,16 +839,15 @@ class ClassificationProject(object):
         logger.info("Train model")
         if not self.balance_dataset:
             try:
-                self.shuffle_training_data()
                 self.is_training = True
                 self.model.fit(self.x_train,
                               # the reshape might be unnecessary here
                                self.y_train.reshape(-1, 1),
                                epochs=epochs,
-                               validation_split = self.validation_split,
+                               validation_split=self.validation_split,
                                # we have to multiply by class weight since keras ignores class weight if sample weight is given
                                # see https://github.com/keras-team/keras/issues/497
-                               sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]/self.mean_train_weight,
+                               sample_weight=self.w_train_tot,
                                shuffle=True,
                                batch_size=self.batch_size,
                                callbacks=self.callbacks_list)
@@ -860,7 +856,6 @@ class ClassificationProject(object):
                 logger.info("Interrupt training - continue with rest")
         else:
             try:
-                self.shuffle_training_data() # needed here too, in order to get correct validation data
                 self.is_training = True
                 labels, label_counts = np.unique(self.y_train, return_counts=True)
                 logger.info("Training on balanced batches")
@@ -868,7 +863,7 @@ class ClassificationProject(object):
                 self.model.fit_generator(self.yield_balanced_batch(),
                                          steps_per_epoch=int(min(label_counts)/self.batch_size),
                                          epochs=epochs,
-                                         validation_data=self.class_weighted_validation_data,
+                                         validation_data=self.validation_data,
                                          callbacks=self.callbacks_list)
                 self.is_training = False
             except KeyboardInterrupt:
@@ -1012,33 +1007,39 @@ class ClassificationProject(object):
         fig, ax = plt.subplots()
         bkg = self.x_train[:,var_index][self.y_train == 0]
         sig = self.x_train[:,var_index][self.y_train == 1]
-        bkg_weights = self.w_train[self.y_train == 0]
-        sig_weights = self.w_train[self.y_train == 1]
+        bkg_weights = self.w_train_tot[self.y_train == 0]
+        sig_weights = self.w_train_tot[self.y_train == 1]
+
+        if self.balance_dataset:
+            if len(sig) < len(bkg):
+                logger.warning("Plotting only up to {} bkg events, since we use balance_dataset".format(len(sig)))
+                bkg = bkg[:len(sig)]
+                bkg_weights = bkg_weights[:len(sig)]
+            else:
+                logger.warning("Plotting only up to {} sig events, since we use balance_dataset".format(len(bkg)))
+                sig = sig[:len(bkg)]
+                sig_weights = sig_weights[:len(bkg)]
 
         logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
         logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
 
         # calculate percentiles to get a heuristic for the range to be plotted
-        # (should in principle also be done with weights, but for now do it unweighted)
-        # range_sig = np.percentile(sig, [1, 99])
-        # range_bkg = np.percentile(sig, [1, 99])
-        # plot_range = (min(range_sig[0], range_bkg[0]), max(range_sig[1], range_sig[1]))
         plot_range = weighted_quantile(
             self.x_train[:,var_index], [0.01, 0.99],
-            sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]
+            sample_weight=self.w_train_tot
         )
 
         logger.debug("Calculated range based on percentiles: {}".format(plot_range))
 
         try:
-            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
-            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
+            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=50, range=plot_range, weights=bkg_weights)
         except ValueError:
            # workaround for a numpy bug; not guaranteed to always work
             plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1])))
             logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
-            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
-            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
+            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=50, range=plot_range, weights=bkg_weights)
 
         width = centers_sig[1]-centers_sig[0]
         ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
@@ -1600,7 +1601,7 @@ class ClassificationProjectRNN(ClassificationProject):
             self.model.fit_generator(self.yield_batch(),
                                      steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
                                      epochs=epochs,
-                                     validation_data=self.class_weighted_validation_data,
+                                     validation_data=self.validation_data,
                                      callbacks=self.callbacks_list)
             self.is_training = False
         except KeyboardInterrupt:
@@ -1635,9 +1636,9 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     @property
-    def class_weighted_validation_data(self):
+    def validation_data(self):
         "class weighted validation data. Attention: Shuffle training data before using this!"
-        x_val, y_val, w_val = super(ClassificationProjectRNN, self).class_weighted_validation_data
+        x_val, y_val, w_val = super(ClassificationProjectRNN, self).validation_data
         x_val_input = self.get_input_list(x_val)
         return x_val_input, y_val, w_val
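
A few notes on the pieces this patch touches, with small standalone sketches. All toy values and helper definitions below are illustrative, not code from the repository.

The core of the change is the new `w_train_tot` property: the per-event sample weight multiplied by the per-class weight, looked up via numpy fancy indexing. A minimal sketch of that lookup:

```python
import numpy as np

# toy per-event sample weights and binary labels (0 = bkg, 1 = sig)
w_train = np.array([1.0, 0.5, 2.0, 1.5])
y_train = np.array([0., 1., 0., 1.])
class_weight = [0.25, 4.0]  # hypothetical [bkg, sig] class weights

# np.array(class_weight)[y_train.astype(int)] looks up the class weight
# for every event, so the product is the per-event total weight
w_train_tot = w_train * np.array(class_weight)[y_train.astype(int)]
print(w_train_tot)  # [0.25 2.   0.5  6.  ]
```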
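The extended `shuffle_training_data` relies on resetting the numpy RNG state between calls to `np.random.shuffle`, so that `x_train`, `y_train`, `w_train` and now `w_train_tot` all receive the identical permutation (shuffling equal-length arrays from the same RNG state draws the same random sequence). A sketch of the pattern:

```python
import numpy as np

x = np.arange(10).reshape(5, 2)
y = np.arange(5)
w = np.linspace(0.1, 0.5, 5)

# restore the same RNG state before each shuffle so every array
# is permuted identically
rn_state = np.random.get_state()
np.random.shuffle(x)  # permutes rows of x in place
np.random.set_state(rn_state)
np.random.shuffle(y)
np.random.set_state(rn_state)
np.random.shuffle(w)

# rows of x stay aligned with their labels
assert all(y[i] == x[i, 0] // 2 for i in range(5))
```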
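`training_data` and `validation_data` both split at `int((1-validation_split)*len(x_train))`, the same take-the-tail convention Keras uses for its `validation_split` argument; that is why `shuffle_training_data()` is now called once up front instead of inside each training branch. A sketch with hypothetical toy arrays:

```python
import numpy as np

# hypothetical stand-ins for the project's training arrays
x_train = np.random.normal(size=(100, 3))
y_train = np.random.randint(0, 2, size=100).astype(float)
w_train_tot = np.ones(100)
validation_split = 0.2

# tail split, matching Keras' validation_split convention
split_index = int((1 - validation_split) * len(x_train))
training_data = (x_train[:split_index], y_train[:split_index], w_train_tot[:split_index])
validation_data = (x_train[split_index:], y_train[split_index:], w_train_tot[split_index:])
# without shuffling first, the tail could be a contiguous block of one
# sample or one class, biasing the validation set
```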
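The input-variable plot range is now computed by `weighted_quantile` with the total weights. The helper itself is not part of this diff; a common recipe for such a function (an assumption, not necessarily the implementation toolkit.py imports) looks like:

```python
import numpy as np

def weighted_quantile(values, quantiles, sample_weight=None):
    """Common recipe for weighted quantiles; an assumption, not
    necessarily the helper toolkit.py actually uses."""
    values = np.asarray(values, dtype=float)
    quantiles = np.asarray(quantiles, dtype=float)
    if sample_weight is None:
        sample_weight = np.ones_like(values)
    sample_weight = np.asarray(sample_weight, dtype=float)
    sorter = np.argsort(values)
    values, sample_weight = values[sorter], sample_weight[sorter]
    # position of each value in the cumulative weight distribution
    cdf = np.cumsum(sample_weight) - 0.5 * sample_weight
    cdf /= np.sum(sample_weight)
    return np.interp(quantiles, cdf, values)

plot_range = weighted_quantile(np.random.normal(size=1000), [0.01, 0.99])
```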
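Finally, since `training_data` now returns `w_train_tot` (which already contains `balanced_class_weight` when `balance_dataset` is set), `yield_single_class_batch` must not multiply by `balanced_class_weight[class_label]` again; that would double-count the class weight. A hypothetical standalone sketch of the generator pattern:

```python
import numpy as np

def yield_single_class_batch(x_train, y_train, w_train, class_label, batch_size):
    # hypothetical standalone version: loop forever over shuffled events
    # of one class, yielding half-batches (sig and bkg are later combined)
    class_idx = np.where(y_train == class_label)[0]
    while True:
        shuffled_idx = np.random.permutation(class_idx)
        for start in range(0, len(shuffled_idx), batch_size // 2):
            sel = shuffled_idx[start:start + batch_size // 2]
            yield x_train[sel], y_train[sel], w_train[sel]
```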