diff --git a/toolkit.py b/toolkit.py
index 9573f262c5e005f15e578e0d84ec9be621994143..191991f6c4a6b2fcdff6299187d5d40d65971917 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -198,8 +198,6 @@ class ClassificationProject(object):
 
         self._scaler = None
         self._class_weight = None
-        self._bkg_weights = None
-        self._sig_weights = None
         self._model = None
         self._history = None
         self._callbacks_list = []
@@ -476,8 +474,13 @@ class ClassificationProject(object):
         return self._class_weight
 
 
-    def load(self):
+    def load(self, reload=False):
         "Load all data needed for plotting and training"
+
+        if reload:
+            self.data_loaded = False
+            self.data_transformed = False
+
         if not self.data_loaded:
             self._load_data()
 
@@ -492,9 +495,10 @@ class ClassificationProject(object):
         np.random.shuffle(self.y_train)
         np.random.set_state(rn_state)
         np.random.shuffle(self.w_train)
-        if self._scores_test is not None:
+        if self._scores_train is not None:
+            logger.info("Shuffling scores, since they are also there")
             np.random.set_state(rn_state)
-            np.random.shuffle(self._scores_test)
+            np.random.shuffle(self._scores_train)
 
 
     def train(self, epochs=10):
@@ -531,12 +535,18 @@ class ClassificationProject(object):
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
 
+        logger.info("Reloading (and re-transforming) unshuffled training data")
+        self.load(reload=True)
+
         logger.info("Create/Update scores for ROC curve")
         self.scores_test = self.model.predict(self.x_test)
         self.scores_train = self.model.predict(self.x_train)
 
         self._dump_to_hdf5("scores_train", "scores_test")
 
+        logger.info("Creating all validation plots")
+        self.plot_all()
+
 
 
     def evaluate(self, x_eval):
@@ -587,34 +597,13 @@ class ClassificationProject(object):
         pass
 
 
-    @property
-    def bkg_weights(self):
-        """
-        class weights multiplied by event weights (for plotting)
-        TODO: find a better way to do this
-        """
-        if self._bkg_weights is None:
-            logger.debug("Calculating background weights for plotting")
-            self._bkg_weights = np.empty(sum(self.y_train == 0))
-            self._bkg_weights.fill(self.class_weight[0])
-            self._bkg_weights *= self.w_train[self.y_train == 0]
-            logger.debug("Background weights: {}".format(self._bkg_weights))
-        return self._bkg_weights
-
-
-    @property
-    def sig_weights(self):
-        """
-        class weights multiplied by event weights (for plotting)
-        TODO: find a better way to do this
-        """
-        if self._sig_weights is None:
-            logger.debug("Calculating signal weights for plotting")
-            self._sig_weights = np.empty(sum(self.y_train == 1))
-            self._sig_weights.fill(self.class_weight[1])
-            self._sig_weights *= self.w_train[self.y_train == 1]
-            logger.debug("Signal weights: {}".format(self._sig_weights))
-        return self._sig_weights
+    @staticmethod
+    def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
+        hist, bins = np.histogram(x, **np_kwargs)
+        centers = (bins[:-1] + bins[1:]) / 2
+        if scale_factor is not None:
+            hist *= scale_factor
+        return centers, hist
 
 
     def plot_input(self, var_index):
@@ -623,6 +612,8 @@ class ClassificationProject(object):
         fig, ax = plt.subplots()
         bkg = self.x_train[:,var_index][self.y_train == 0]
         sig = self.x_train[:,var_index][self.y_train == 1]
+        bkg_weights = self.w_train[self.y_train == 0]
+        sig_weights = self.w_train[self.y_train == 1]
 
         logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
         logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
@@ -636,14 +627,19 @@ class ClassificationProject(object):
         logger.debug("Calculated range based on percentiles: {}".format(plot_range))
 
         try:
-            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
-            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
+            centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
         except ValueError:
             # weird, probably not always working workaround for a numpy bug
             plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
             logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
-            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
-            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
+            centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
+
+        width = centers_sig[1]-centers_sig[0]
+        ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
+        ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
+
         ax.set_xlabel(branch+" (transformed)")
         plot_dir = os.path.join(self.project_dir, "plots")
         if not os.path.exists(plot_dir):
@@ -685,8 +681,24 @@ class ClassificationProject(object):
         plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
         plt.clf()
 
+
     def plot_score(self):
-        pass
+        plot_opts = dict(bins=50, range=(0, 1))
+        centers_sig_train, hist_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
+        centers_bkg_train, hist_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
+        centers_sig_test, hist_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts)
+        centers_bkg_test, hist_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts)
+        fig, ax = plt.subplots()
+        width = centers_sig_train[1]-centers_sig_train[0]
+        ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
+        ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
+        ax.scatter(centers_bkg_test, hist_bkg_test, color="b", label="background test")
+        ax.scatter(centers_sig_test, hist_sig_test, color="r", label="signal test")
+        ax.set_yscale("log")
+        ax.set_xlabel("NN output")
+        plt.legend(loc='upper right', framealpha=1.0)
+        fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
+
 
 
     @property
@@ -743,6 +755,15 @@ class ClassificationProject(object):
         plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
         plt.clf()
 
+
+    def plot_all(self):
+        self.plot_ROC()
+        self.plot_accuracy()
+        self.plot_loss()
+        self.plot_score()
+        self.plot_weights()
+
+
 def create_getter(dataset_name):
     def getx(self):
         if getattr(self, "_"+dataset_name) is None:
@@ -789,9 +810,7 @@ if __name__ == "__main__":
 
     np.random.seed(42)
     c.train(epochs=20)
-    c.plot_ROC()
-    c.plot_loss()
-    c.plot_accuracy()
+    #c.plot_all()
 
     # c.write_friend_tree("test4_score",
     #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",