From 133b7c966ac7145a80eea053fe10bcea8befd247 Mon Sep 17 00:00:00 2001 From: Eric Schanet <eric.schanet@cern.ch> Date: Thu, 3 May 2018 16:01:43 +0200 Subject: [PATCH] Same color for test and validation loss function --- compare.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/compare.py b/compare.py index 8b6409f..4f2332d 100755 --- a/compare.py +++ b/compare.py @@ -21,7 +21,7 @@ def overlay_ROC(filename, *projects): fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight = p.w_test) fpr = 1.0 - fpr roc_auc = auc(tpr, fpr) - + plt.grid(color='gray', linestyle='--', linewidth=1) plt.plot(tpr, fpr, label=str(p.name+" (AUC = {})".format(roc_auc))) @@ -41,9 +41,13 @@ def overlay_loss(filename, *projects): logger.info("Overlay loss curves for {}".format([p.name for p in projects])) - for p in projects: - plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name) - plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name) + prop_cycle = plt.rcParams['axes.prop_cycle'] + colors = prop_cycle.by_key()['color'] + + for p,color in zip(projects,colors): + plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name, color=color) + plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name, color=color) + plt.ylabel('loss') plt.xlabel('epoch') plt.legend(loc='upper right') -- GitLab