From 47ef6c18800c58214c6fe8bb55d3d8afb627d49f Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Wed, 9 May 2018 11:43:17 +0200
Subject: [PATCH] Adding protection for missing history in case training was
 aborted during first epoch

---
 toolkit.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 3416ef4..5edaad7 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -391,10 +391,13 @@ class ClassificationProject(object):
         if self._history is None:
             self._history = History()
             if os.path.exists(params_file) and os.path.exists(history_file):
-                with open(params_file) as f:
-                    self._history.params = json.load(f)
-                with open(history_file) as f:
-                    self._history.history = json.load(f)
+                try:
+                    with open(params_file) as f:
+                        self._history.params = json.load(f)
+                    with open(history_file) as f:
+                        self._history.history = json.load(f)
+                except ValueError:
+                    logger.warning("Couldn't load history - starting with empty one")
         return self._history
 
 
@@ -765,6 +768,10 @@ class ClassificationProject(object):
         else:
             hist_dict = self.history.history
 
+        if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict):
+            logger.warning("No previous history found for plotting, try global history")
+            hist_dict = self.csv_hist
+
         logger.info("Plot losses")
         plt.plot(hist_dict['loss'])
         plt.plot(hist_dict['val_loss'])
@@ -787,6 +794,10 @@ class ClassificationProject(object):
         else:
             hist_dict = self.history.history
 
+        if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict):
+            logger.warning("No previous history found for plotting, try global history")
+            hist_dict = self.csv_hist
+
         logger.info("Plot accuracy")
         plt.plot(hist_dict['acc'])
         plt.plot(hist_dict['val_acc'])
-- 
GitLab