diff --git a/compare.py b/compare.py index 8b6409f3170c7d39c66ebb3efac922c651922137..4f2332de6b8a54898e6be5a17c9cbad6212aa0c2 100755 --- a/compare.py +++ b/compare.py @@ -21,7 +21,7 @@ def overlay_ROC(filename, *projects): fpr, tpr, threshold = roc_curve(p.y_test, p.scores_test, sample_weight = p.w_test) fpr = 1.0 - fpr roc_auc = auc(tpr, fpr) - + plt.grid(color='gray', linestyle='--', linewidth=1) plt.plot(tpr, fpr, label=str(p.name+" (AUC = {})".format(roc_auc))) @@ -41,9 +41,13 @@ def overlay_loss(filename, *projects): logger.info("Overlay loss curves for {}".format([p.name for p in projects])) - for p in projects: - plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name) - plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name) + prop_cycle = plt.rcParams['axes.prop_cycle'] + colors = prop_cycle.by_key()['color'] + + for p,color in zip(projects,colors): + plt.semilogy(p.history.history['loss'], linestyle='--', label="Training Loss "+p.name, color=color) + plt.semilogy(p.history.history['val_loss'], label="Validation Loss "+p.name, color=color) + plt.ylabel('loss') plt.xlabel('epoch') plt.legend(loc='upper right')