Skip to content
Snippets Groups Projects
Commit 46cb66fa authored by Nikolai's avatar Nikolai
Browse files

put model checkpoint/reload weights into separate function

parent f14c65cd
No related branches found
No related tags found
No related merge requests found
......@@ -842,6 +842,11 @@ class ClassificationProject(object):
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
self.checkpoint_model(epochs)
def checkpoint_model(self, epochs):
logger.info("Save history")
self._dump_history()
......@@ -1455,6 +1460,8 @@ class ClassificationProjectRNN(ClassificationProject):
for branch_index, branch in enumerate(self.fields):
self.plot_input(branch_index)
self.total_epochs = self._read_info("epochs", 0)
try:
self.shuffle_training_data() # needed here too, in order to get correct validation data
self.is_training = True
......@@ -1468,8 +1475,8 @@ class ClassificationProjectRNN(ClassificationProject):
self.is_training = False
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
logger.info("Save history")
self._dump_history()
self.checkpoint_model(epochs)
def get_input_list(self, x):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment