diff --git a/toolkit.py b/toolkit.py
index 7583548fce9ef49593831299bc1a30051de8a2a5..d198121f4fe7f5a9cb944ad3e577c0528914aa9f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -338,6 +338,8 @@ class ClassificationProject(object):
 
         self._fields = None
 
+        self._mean_train_weight = None
+
 
     @property
     def fields(self):
@@ -741,6 +743,13 @@ class ClassificationProject(object):
             np.random.shuffle(self._scores_train)
 
 
+    @property
+    def mean_train_weight(self):
+        if self._mean_train_weight is None:
+            self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)])
+        return self._mean_train_weight
+
+
     @property
     def w_validation(self):
         "class weighted validation data weights"
@@ -749,7 +758,7 @@ class ClassificationProject(object):
             self._w_validation = np.array(self.w_train[split_index:])
             self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
             self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
-        return self._w_validation
+        return self._w_validation/self.mean_train_weight
 
 
     @property
@@ -819,7 +828,7 @@ class ClassificationProject(object):
                                validation_split = self.validation_split,
                                # we have to multiply by class weight since keras ignores class weight if sample weight is given
                                # see https://github.com/keras-team/keras/issues/497
-                               sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)],
+                               sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]/self.mean_train_weight,
                                shuffle=True,
                                batch_size=self.batch_size,
                                callbacks=self.callbacks_list)
@@ -842,6 +851,11 @@ class ClassificationProject(object):
             except KeyboardInterrupt:
                 logger.info("Interrupt training - continue with rest")
 
+        self.checkpoint_model(epochs)
+
+
+    def checkpoint_model(self, epochs):
+
         logger.info("Save history")
         self._dump_history()
 
@@ -1455,6 +1469,8 @@ class ClassificationProjectRNN(ClassificationProject):
         for branch_index, branch in enumerate(self.fields):
             self.plot_input(branch_index)
 
+        self.total_epochs = self._read_info("epochs", 0)
+
         try:
             self.shuffle_training_data() # needed here too, in order to get correct validation data
             self.is_training = True
@@ -1468,8 +1484,8 @@ class ClassificationProjectRNN(ClassificationProject):
             self.is_training = False
         except KeyboardInterrupt:
             logger.info("Interrupt training - continue with rest")
-        logger.info("Save history")
-        self._dump_history()
+
+        self.checkpoint_model(epochs)
 
 
     def get_input_list(self, x):
@@ -1494,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject):
                 x_input = self.get_input_list(x_batch)
                 yield (x_input,
                        y_train[shuffled_idx[start:start+int(self.batch_size)]],
-                       w_batch*np.array(self.class_weight)[y_batch.astype(int)])
+                       w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight)
 
 
     @property