diff --git a/toolkit.py b/toolkit.py
index 01b17e77ba2f28a30a48dbda3c9a36771b07f5e7..dbec934f105a28f16fcb0d20b43f42133310b599 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -27,7 +27,7 @@ from sklearn.metrics import roc_curve, auc
 from keras.models import Sequential
 from keras.layers import Dense, Dropout
 from keras.models import model_from_json
-from keras.callbacks import History, EarlyStopping, CSVLogger
+from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint
 from keras.optimizers import SGD
 import keras.optimizers
 
@@ -157,6 +157,7 @@ class ClassificationProject(object):
                         optimizer_opts=None,
                         use_earlystopping=True,
                         earlystopping_opts=None,
+                        use_modelcheckpoint=True,
                         random_seed=1234):
 
         self.name = name
@@ -178,6 +179,7 @@ class ClassificationProject(object):
         self.step_bkg = step_bkg
         self.optimizer = optimizer
         self.use_earlystopping = use_earlystopping
+        self.use_modelcheckpoint = use_modelcheckpoint
         if optimizer_opts is None:
             optimizer_opts = dict()
         self.optimizer_opts = optimizer_opts
@@ -354,6 +356,10 @@ class ClassificationProject(object):
         self._callbacks_list.append(self.history)
         if self.use_earlystopping:
             self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
+        if self.use_modelcheckpoint:
+            self._callbacks_list.append(ModelCheckpoint(save_best_only=True,
+                                                        verbose=True,
+                                                        filepath=os.path.join(self.project_dir, "weights.h5")))
         self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
         return self._callbacks_list
 
@@ -574,8 +580,9 @@ class ClassificationProject(object):
         logger.info("Save history")
         self._dump_history()
 
-        logger.info("Save weights")
-        self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
+        if not self.use_modelcheckpoint:
+            logger.info("Save weights")
+            self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
 
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)