diff --git a/toolkit.py b/toolkit.py index 3416ef4278d6188ba90173952a7cec53ee98b996..5edaad79ff0180440c82c09b844d025a7ce571b6 100755 --- a/toolkit.py +++ b/toolkit.py @@ -391,10 +391,13 @@ class ClassificationProject(object): if self._history is None: self._history = History() if os.path.exists(params_file) and os.path.exists(history_file): - with open(params_file) as f: - self._history.params = json.load(f) - with open(history_file) as f: - self._history.history = json.load(f) + try: + with open(params_file) as f: + self._history.params = json.load(f) + with open(history_file) as f: + self._history.history = json.load(f) + except ValueError: + logger.warning("Couldn't load history - starting with empty one") return self._history @@ -765,6 +768,10 @@ class ClassificationProject(object): else: hist_dict = self.history.history + if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict): + logger.warning("No previous history found for plotting, try global history") + hist_dict = self.csv_hist + logger.info("Plot losses") plt.plot(hist_dict['loss']) plt.plot(hist_dict['val_loss']) @@ -787,6 +794,10 @@ class ClassificationProject(object): else: hist_dict = self.history.history + if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict): + logger.warning("No previous history found for plotting, try global history") + hist_dict = self.csv_hist + logger.info("Plot accuracy") plt.plot(hist_dict['acc']) plt.plot(hist_dict['val_acc'])