diff --git a/toolkit.py b/toolkit.py
index 7583548fce9ef49593831299bc1a30051de8a2a5..d198121f4fe7f5a9cb944ad3e577c0528914aa9f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -338,6 +338,8 @@ class ClassificationProject(object):
 
         self._fields = None
 
+        self._mean_train_weight = None
+
 
     @property
     def fields(self):
@@ -741,6 +743,13 @@ class ClassificationProject(object):
             np.random.shuffle(self._scores_train)
 
 
+    @property
+    def mean_train_weight(self):
+        if self._mean_train_weight is None:
+            self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)])
+        return self._mean_train_weight
+
+
     @property
     def w_validation(self):
         "class weighted validation data weights"
@@ -749,7 +758,7 @@ class ClassificationProject(object):
             self._w_validation = np.array(self.w_train[split_index:])
             self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
             self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
-        return self._w_validation
+        return self._w_validation/self.mean_train_weight
 
 
     @property
@@ -819,7 +828,7 @@ class ClassificationProject(object):
                            validation_split = self.validation_split,
                            # we have to multiply by class weight since keras ignores class weight if sample weight is given
                            # see https://github.com/keras-team/keras/issues/497
-                           sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)],
+                           sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]/self.mean_train_weight,
                            shuffle=True,
                            batch_size=self.batch_size,
                            callbacks=self.callbacks_list)
@@ -842,6 +851,11 @@ class ClassificationProject(object):
         except KeyboardInterrupt:
             logger.info("Interrupt training - continue with rest")
 
+        self.checkpoint_model(epochs)
+
+
+    def checkpoint_model(self, epochs):
+
         logger.info("Save history")
         self._dump_history()
 
@@ -1455,6 +1469,8 @@ class ClassificationProjectRNN(ClassificationProject):
         for branch_index, branch in enumerate(self.fields):
             self.plot_input(branch_index)
 
+        self.total_epochs = self._read_info("epochs", 0)
+
         try:
             self.shuffle_training_data() # needed here too, in order to get correct validation data
             self.is_training = True
@@ -1468,8 +1484,8 @@ class ClassificationProjectRNN(ClassificationProject):
             self.is_training = False
         except KeyboardInterrupt:
             logger.info("Interrupt training - continue with rest")
-        logger.info("Save history")
-        self._dump_history()
+
+        self.checkpoint_model(epochs)
 
 
     def get_input_list(self, x):
@@ -1494,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject):
                 x_input = self.get_input_list(x_batch)
                 yield (x_input,
                        y_train[shuffled_idx[start:start+int(self.batch_size)]],
-                       w_batch*np.array(self.class_weight)[y_batch.astype(int)])
+                       w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight)
 
 
     @property
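
The effect of the `mean_train_weight` normalization introduced above can be illustrated with a small standalone sketch (toy arrays and a made-up `class_weight` list, not the project's actual data). Since Keras ignores `class_weight` when `sample_weight` is given (see the keras issue linked in the diff), the per-sample losses are scaled directly by the supplied weights, so an overall scale factor in the event weights rescales the loss and, effectively, the gradient magnitude. Dividing by the mean of the class-weighted training weights keeps the effective weights at mean 1, independent of the absolute scale of the raw weights:

```python
import numpy as np

# Toy stand-ins (hypothetical values, not the project's actual data):
# raw per-event weights and binary labels for a small "training set"
w_train = np.array([0.2, 5.0, 0.7, 3.1, 0.05, 1.3])
y_train = np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
class_weight = [2.0, 0.5]  # e.g. to balance the two classes

# Effective weights as they are passed to model.fit(..., sample_weight=...):
# raw event weight times the class weight of each event's label
effective = w_train * np.array(class_weight)[y_train.astype(int)]

# Dividing by the mean keeps the average effective weight at 1, so the
# absolute scale of the raw weights no longer rescales the loss
mean_train_weight = np.mean(effective)
normalized = effective / mean_train_weight

print("mean of effective weights:  ", effective.mean())
print("mean of normalized weights: ", normalized.mean())  # ~1.0
```

The same normalization is applied consistently to the validation weights and to the batches yielded by the RNN generator, so training and validation losses stay on a comparable scale.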