From 133b7c966ac7145a80eea053fe10bcea8befd247 Mon Sep 17 00:00:00 2001
From: Eric Schanet <eric.schanet@cern.ch>
Date: Thu, 3 May 2018 16:01:43 +0200
Subject: [PATCH] Same color for test and validation loss function

---
 compare.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/compare.py b/compare.py
index 8b6409f..4f2332d 100755
--- a/compare.py
+++ b/compare.py
@@ -21,7 +21,7 @@ def overlay_ROC(filename, *projects):
         fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight = p.w_test)
         fpr = 1.0 - fpr
         roc_auc = auc(tpr, fpr)
-        
+
         plt.grid(color='gray', linestyle='--', linewidth=1)
         plt.plot(tpr,  fpr, label=str(p.name+" (AUC = {})".format(roc_auc)))
 
@@ -41,9 +41,13 @@ def overlay_loss(filename, *projects):
 
     logger.info("Overlay loss curves for {}".format([p.name for p in projects]))
 
-    for p in projects:
-        plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name)
-        plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name)
+    prop_cycle = plt.rcParams['axes.prop_cycle']
+    colors = prop_cycle.by_key()['color']
+
+    for p,color in zip(projects,colors):
+        plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name, color=color)
+        plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name, color=color)
+        
     plt.ylabel('loss')
     plt.xlabel('epoch')
     plt.legend(loc='upper right')
-- 
GitLab