From a35af907ce85caf367905f5be4f47c47a0542ae8 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 8 May 2018 19:32:23 +0200
Subject: [PATCH] Plot score for training and test data, signal and background

---
 toolkit.py | 48 ++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 205b3ca..a29b3af 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -472,8 +472,13 @@ class KerasROOTClassification(object):
         return self._class_weight
 
 
-    def load(self):
+    def load(self, reload=False):
         "Load all data needed for plotting and training"
+
+        if reload:
+            self.data_loaded = False
+            self.data_transformed = False
+
         if not self.data_loaded:
             self._load_data()
 
@@ -489,6 +494,7 @@ class KerasROOTClassification(object):
         np.random.set_state(rn_state)
         np.random.shuffle(self.w_train)
         if self._scores_train is not None:
+            logger.info("Shuffling scores, since they are also there")
             np.random.set_state(rn_state)
             np.random.shuffle(self._scores_train)
 
@@ -528,12 +534,18 @@ class KerasROOTClassification(object):
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
 
+        logger.info("Reloading (and re-transforming) unshuffled training data")
+        self.load(reload=True)
+
         logger.info("Create/Update scores for ROC curve")
         self.scores_test = self.model.predict(self.x_test)
         self.scores_train = self.model.predict(self.x_train)
 
         self._dump_to_hdf5("scores_train", "scores_test")
 
+        logger.info("Creating all validation plots")
+        self.plot_all()
+
 
 
     def evaluate(self, x_eval):
@@ -588,7 +600,8 @@ class KerasROOTClassification(object):
     def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
         hist, bins = np.histogram(x, **np_kwargs)
         centers = (bins[:-1] + bins[1:]) / 2
-        hist *= scale_factor
+        if scale_factor is not None:
+            hist *= scale_factor
         return centers, hist
 
 
@@ -667,8 +680,24 @@ class KerasROOTClassification(object):
         plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
         plt.clf()
 
+
     def plot_score(self):
-        pass
+        """Save weighted, unit-area NN score histograms (train as bars, test as markers; signal and background) to scores.pdf."""
+        plot_opts = dict(bins=50, range=(0, 1))
+        centers_sig_train, hist_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
+        centers_bkg_train, hist_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
+        centers_sig_test, hist_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts)
+        centers_bkg_test, hist_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts)
+        fig, ax = plt.subplots()
+        width = centers_sig_train[1]-centers_sig_train[0]  # all four histograms share plot_opts, so one bin width fits every bar
+        ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
+        ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
+        ax.scatter(centers_bkg_test, hist_bkg_test, color="b", label="background test")
+        ax.scatter(centers_sig_test, hist_sig_test, color="r", label="signal test")
+        ax.set(yscale="log", xlabel="NN output")
+        ax.legend(loc='upper right', framealpha=1.0)
+        fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
+        plt.close(fig)  # plt.subplots() made this the current figure; close it so later pyplot calls do not draw onto it or leak it
 
 
     def plot_loss(self):
@@ -695,6 +724,15 @@ class KerasROOTClassification(object):
         plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
         plt.clf()
 
+
+    def plot_all(self):  # convenience wrapper: calls the five plot_* methods below, one after another
+        self.plot_ROC()
+        self.plot_accuracy()
+        self.plot_loss()
+        self.plot_score()  # reads scores_train / scores_test and the matching labels and weights
+        self.plot_weights()  # NOTE(review): plot_weights is not part of this patch -- confirm it exists in toolkit.py
+
+
 def create_getter(dataset_name):
     def getx(self):
         if getattr(self, "_"+dataset_name) is None:
@@ -739,9 +777,7 @@ if __name__ == "__main__":
 
     np.random.seed(42)
     c.train(epochs=20)
-    c.plot_ROC()
-    c.plot_loss()
-    c.plot_accuracy()
+    #c.plot_all()
 
     # c.write_friend_tree("test4_score",
     #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
-- 
GitLab