From 77dd63a536f674171e50e50802ac830a420946b2 Mon Sep 17 00:00:00 2001
From: Thomas Weber <Thomas.Weber@physik.uni-muenchen.de>
Date: Fri, 27 Apr 2018 11:28:07 +0200
Subject: [PATCH] Added plot functions for losses and accuracy

---
 toolkit.py | 34 ++++++++++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index f679f69..1f86123 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -44,7 +44,7 @@ class KerasROOTClassification:
 
     def __init__(self, name,
                  signal_trees, bkg_trees, branches, weight_expr, identifiers,
-                 layers=3, nodes=64, batch_size=128, activation_function='relu', out_dir="./outputs"):
+                 layers=3, nodes=64, batch_size=128, validation_split=0.33, activation_function='relu', out_dir="./outputs"):
         self.name = name
         self.signal_trees = signal_trees
         self.bkg_trees = bkg_trees
@@ -54,6 +54,7 @@ class KerasROOTClassification:
         self.layers = layers
         self.nodes = nodes
         self.batch_size = batch_size
+        self.validation_split = validation_split
         self.activation_function = activation_function
         self.out_dir = out_dir
 
@@ -82,6 +83,7 @@ class KerasROOTClassification:
         self._bkg_weights = None
         self._sig_weights = None
         self._model = None
+        self._history = None
 
         self.score_train = None
         self.score_test = None
@@ -280,10 +282,11 @@ class KerasROOTClassification:
         self.total_epochs = self._read_info("epochs", 0)
 
         logger.info("Train model")
-        self.model.fit(self.x_train,
+        self._history = self.model.fit(self.x_train,
                        # the reshape might be unnescessary here
                        self.y_train.reshape(-1, 1),
                        epochs=epochs,
+                       validation_split = self.validation_split,
                        class_weight=self.class_weight,
                        shuffle=True,
                        batch_size=self.batch_size)
@@ -376,7 +379,28 @@ class KerasROOTClassification:
     def plot_score(self):
         pass
 
-
+    
+    def plot_loss(self):
+
+        logger.info("Plot losses")
+        plt.plot(self._history.history['loss'])
+        plt.plot(self._history.history['val_loss'])
+        plt.ylabel('loss')
+        plt.xlabel('epoch')
+        plt.legend(['train','test'], loc='upper left')
+        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
+    
+
+    def plot_accuracy(self):
+        
+        logger.info("Plot accuracy")
+        plt.plot(self._history.history['acc'])
+        plt.plot(self._history.history['val_acc'])
+        plt.title('model accuracy')
+        plt.ylabel('accuracy')
+        plt.xlabel('epoch')
+        plt.legend(['train', 'test'], loc='upper left')
+        plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
 
 if __name__ == "__main__":
 
@@ -395,5 +419,7 @@ if __name__ == "__main__":
                                 weight_expr = "eventWeight*genWeight",
                                 identifiers = ["DatasetNumber", "EventNumber"])
 
-    c.train(epochs=1)
+    c.train(epochs=20)
     c.plot_ROC()
+    c.plot_loss()
+    c.plot_accuracy()
-- 
GitLab