From 47ef6c18800c58214c6fe8bb55d3d8afb627d49f Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Wed, 9 May 2018 11:43:17 +0200 Subject: [PATCH] Adding protection for missing history in case training was aborted during first epoch --- toolkit.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/toolkit.py b/toolkit.py index 3416ef4..5edaad7 100755 --- a/toolkit.py +++ b/toolkit.py @@ -391,10 +391,13 @@ class ClassificationProject(object): if self._history is None: self._history = History() if os.path.exists(params_file) and os.path.exists(history_file): - with open(params_file) as f: - self._history.params = json.load(f) - with open(history_file) as f: - self._history.history = json.load(f) + try: + with open(params_file) as f: + self._history.params = json.load(f) + with open(history_file) as f: + self._history.history = json.load(f) + except ValueError: + logger.warning("Couldn't load history - starting with empty one") return self._history @@ -765,6 +768,10 @@ class ClassificationProject(object): else: hist_dict = self.history.history + if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict): + logger.warning("No previous history found for plotting, try global history") + hist_dict = self.csv_hist + logger.info("Plot losses") plt.plot(hist_dict['loss']) plt.plot(hist_dict['val_loss']) @@ -787,6 +794,10 @@ class ClassificationProject(object): else: hist_dict = self.history.history + if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict): + logger.warning("No previous history found for plotting, try global history") + hist_dict = self.csv_hist + logger.info("Plot accuracy") plt.plot(hist_dict['acc']) plt.plot(hist_dict['val_acc']) -- GitLab