From 660b4beb75b7ef29782ae946647d49e0dc660571 Mon Sep 17 00:00:00 2001
From: Thomas Weber <Thomas.Weber@physik.uni-muenchen.de>
Date: Thu, 26 Apr 2018 20:07:32 +0200
Subject: [PATCH] Added plotROC function and scaler function call

---
 toolkit.py | 46 ++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 6d86077..7b5892b 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -13,11 +13,14 @@ import pandas as pd
 import h5py
 from sklearn.preprocessing import StandardScaler
 from sklearn.externals import joblib
+from sklearn.metrics import roc_curve
 
 from keras.models import Sequential
 from keras.layers import Dense
 from keras.models import model_from_json
 
+import matplotlib.pyplot as plt
+
 # configure number of cores
 # this doesn't seem to work, but at least with these settings keras only uses 4 processes
 import tensorflow as tf
@@ -77,6 +80,9 @@ class KerasROOTClassification:
         self._class_weight = None
         self._model = None
 
+        self.score_train = None
+        self.score_test = None
+
         # track the number of epochs this model has been trained
         self.total_epochs = 0
 
@@ -173,6 +179,8 @@ class KerasROOTClassification:
                 self._scaler = StandardScaler()
                 logger.info("Fitting StandardScaler to training data")
                 self._scaler.fit(self.x_train)
+                logger.info("Fitting StandardScaler to test data")
+                self._scaler.fit(self.x_test)
                 joblib.dump(self._scaler, filename)
         return self._scaler
 
@@ -211,7 +219,8 @@ class KerasROOTClassification:
                 self._model.add(Dense(self.nodes, activation=self.activation_function))
             # last layer is one neuron (binary classification)
             self._model.add(Dense(1, activation='sigmoid'))
-
+            
+            logger.info("Compile model")
             self._model.compile(optimizer='SGD',
                   loss='binary_crossentropy',
                   metrics=['accuracy'])
@@ -234,6 +243,8 @@ class KerasROOTClassification:
 
         if not self.data_loaded:
             self._load_data()
+        
+        self.scaler
 
         try:
             self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
@@ -243,17 +254,25 @@ class KerasROOTClassification:
             logger.info("No weights found, starting completely new training")
 
         self.total_epochs = self._read_info("epochs", 0)
-
+        
+        logger.info("Train model")
         self.model.fit(self.x_train, self.y_train,
                        epochs=epochs,
                        class_weight=self.class_weight,
                        shuffle=True,
                        batch_size=self.batch_size)
-
+        
+        logger.info("Save weights")
         self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
 
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
+       
+        logger.info("Create scores for ROC curve") 
+        self.scores_test = self.model.predict(self.x_test)
+        self.scores_train = self.model.predict(self.x_train)
+
+
 
     def evaluate(self):
         pass
@@ -262,7 +281,25 @@ class KerasROOTClassification:
         pass
 
     def plotROC(self):
-        pass
+
+        logger.info("Plot ROC curve")
+        fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test)
+
+        plt.grid(color='gray', linestyle='--', linewidth=1)
+        plt.plot(fpr, tpr, label='NN')
+        plt.plot([0,1],[0,1], linestyle='--', color='black', label='Luck')
+        plt.xlabel("False positive rate (background rejection)")
+        plt.ylabel("True positive rate (signal efficiency)")
+        plt.title('Receiver operating characteristic')
+        plt.xlim(0,1)
+        plt.ylim(0,1)
+        plt.xticks(np.arange(0,1,0.1))
+        plt.yticks(np.arange(0,1,0.1))
+        plt.legend(loc='lower left', framealpha=1.0)
+
+        plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
+        plt.clf()
+
 
     def plotScore(self):
         pass
@@ -286,3 +323,4 @@ if __name__ == "__main__":
                                 identifiers = ["DatasetNumber", "EventNumber"])
 
     c.train(epochs=1)
+    c.plotROC()
-- 
GitLab