From 93b9b214dd95858bddec508e991847a3cab00333 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Mon, 10 Sep 2018 11:12:54 +0200
Subject: [PATCH] always use fit_generator for training + add option for KFold
 cross validation

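Training now always goes through fit_generator with an explicit
train/validation index split instead of relying on Keras'
validation_split. Optionally the split can be taken from a KFold
partition: pass e.g. kfold_splits=5 and kfold_index=2 to the
ClassificationProject constructor to hold out fold 2 of 5 for
validation.

The fold selection mirrors sklearn's KFold. A minimal standalone
sketch with illustrative values (the array is only a stand-in for the
real training data):

    import numpy as np
    from sklearn.model_selection import KFold

    x_train = np.arange(100).reshape(-1, 1)  # placeholder training array
    kfold = KFold(5, shuffle=True, random_state=42)
    for i, (train_idx, val_idx) in enumerate(kfold.split(x_train)):
        if i == 2:  # kfold_index
            break
    # train_idx/val_idx now index the training and validation events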
---
 toolkit.py | 114 +++++++++++++++++++++++++++++------------------------
 1 file changed, 62 insertions(+), 52 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index f108d1d..51efd3f 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -33,6 +33,7 @@ from sklearn.preprocessing import StandardScaler, RobustScaler
 from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
 from sklearn.utils.extmath import stable_cumsum
+from sklearn.model_selection import KFold
 from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
@@ -125,6 +126,10 @@ class ClassificationProject(object):
 
     :param validation_split: split off this fraction of training events for loss evaluation
 
+    :param kfold_splits: if given, split the training data into this number of subsets to perform KFold cross validation
+
+    :param kfold_index: index of the KFold subset to hold out for validation
+
     :param activation_function: activation function in the hidden layers
 
     :param activation_function_output: activation function in the output layer
@@ -224,6 +229,8 @@ class ClassificationProject(object):
                         dropout_input=None,
                         batch_size=128,
                         validation_split=0.33,
+                        kfold_splits=None,
+                        kfold_index=0,
                         activation_function='relu',
                         activation_function_output='sigmoid',
                         scaler_type="WeightedRobustScaler",
@@ -284,6 +291,8 @@ class ClassificationProject(object):
         self.dropout_input = dropout_input
         self.batch_size = batch_size
         self.validation_split = validation_split
+        self.kfold_splits = kfold_splits
+        self.kfold_index = kfold_index
         self.activation_function = activation_function
         self.activation_function_output = activation_function_output
         self.scaler_type = scaler_type
@@ -347,6 +356,7 @@ class ClassificationProject(object):
         self._model = None
         self._history = None
         self._callbacks_list = []
+        self._train_val_idx = None
 
         # track the number of epochs this model has been trained
         self.total_epochs = 0
@@ -842,20 +852,53 @@ class ClassificationProject(object):
 
     @property
     def validation_data(self):
-        "Validation data. Attention: Shuffle training data before using this!"
-        if not self.data_shuffled:
-            raise ValueError("Training data isn't shuffled, can't split of validation data")
-        split_index = int((1-self.validation_split)*len(self.x_train))
-        return self.x_train[split_index:], self.y_train[split_index:], self.w_train_tot[split_index:]
+        "Validation data"
+        idx = self.train_val_idx[1]
+        return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
 
 
     @property
     def training_data(self):
-        "Training data with validation data split off. Attention: Shuffle training data before using this!"
-        if not self.data_shuffled:
-            raise ValueError("Training data isn't shuffled, can't split of validation data")
-        split_index = int((1-self.validation_split)*len(self.x_train))
-        return self.x_train[:split_index], self.y_train[:split_index], self.w_train_tot[:split_index]
+        "Training data with validation data split off"
+        idx = self.train_val_idx[0]
+        return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
+
+
+    @property
+    def train_val_idx(self):
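+        "Tuple (train_idx, val_idx) of index arrays for the train/validation split"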
+        if self._train_val_idx is None:
+            if self.kfold_splits is not None:
+                kfold = KFold(self.kfold_splits, shuffle=True, random_state=self.shuffle_seed)
+                for i, train_val_idx in enumerate(kfold.split(self.x_train)):
+                    if i == self.kfold_index:
+                        self._train_val_idx = train_val_idx
+                        break
+                else:
+                    raise IndexError("Index {} out of range for kfold (requested {} splits)".format(self.kfold_index, self.kfold_splits))
+            else:
+                split_index = int((1-self.validation_split)*len(self.x_train))
+                np.random.seed(self.shuffle_seed)
+                shuffled_idx = np.random.permutation(len(self.x_train))
+                self._train_val_idx = (shuffled_idx[:split_index], shuffled_idx[split_index:])
+        return self._train_val_idx
+
+
+    @property
+    def steps_per_epoch(self):
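+        "Number of batches per epoch: training events divided by batch size, rounded down"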
+        return len(self.train_val_idx[0]) // int(self.batch_size)
+
+
+    def yield_batch(self):
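+        "Yield batches of (x, y, weight) from the training split, reshuffling for each pass"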
+        x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
+        train_idx = list(self.train_val_idx[0])
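+        # the seed is offset by 1 so batch shuffling is not identical to the train/val split permutation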
+        np.random.seed(self.shuffle_seed+1)
+        while True:
+            shuffled_idx = np.random.permutation(train_idx)
+            for start in range(0, len(shuffled_idx), int(self.batch_size)):
+                x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
+                y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
+                w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
+                yield (x_batch, y_batch, w_batch)
 
 
     def yield_single_class_batch(self, class_label):
@@ -894,8 +937,6 @@ class ClassificationProject(object):
 
         self.load()
 
-        self.shuffle_training_data()
-
         for branch_index, branch in enumerate(self.fields):
             self.plot_input(branch_index)
 
@@ -905,18 +946,11 @@ class ClassificationProject(object):
         if not self.balance_dataset:
             try:
                 self.is_training = True
-                np.random.seed(self.shuffle_seed+1) # since we use keras shuffling here
-                self.model.fit(self.x_train,
-                               # the reshape might be unnescessary here
-                               self.y_train.reshape(-1, 1),
-                               epochs=epochs,
-                               validation_split=self.validation_split,
-                               # we have to multiply by class weight since keras ignores class weight if sample weight is given
-                               # see https://github.com/keras-team/keras/issues/497
-                               sample_weight=self.w_train_tot,
-                               shuffle=True,
-                               batch_size=self.batch_size,
-                               callbacks=self.callbacks_list)
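+                # batches from yield_batch carry w_train_tot as sample weights;
+                # class weights are already folded in (see https://github.com/keras-team/keras/issues/497)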
+                self.model.fit_generator(self.yield_batch(),
+                                         steps_per_epoch=self.steps_per_epoch,
+                                         epochs=epochs,
+                                         validation_data=self.validation_data,
+                                         callbacks=self.callbacks_list)
                 self.is_training = False
             except KeyboardInterrupt:
                 logger.info("Interrupt training - continue with rest")
@@ -1761,32 +1795,6 @@ class ClassificationProjectRNN(ClassificationProject):
         return self._model
 
 
-    def train(self, epochs=10):
-        self.load()
-
-        self.shuffle_training_data()
-
-        for branch_index, branch in enumerate(self.fields):
-            self.plot_input(branch_index)
-
-        self.total_epochs = self._read_info("epochs", 0)
-
-        try:
-            self.is_training = True
-            logger.info("Training on batches for RNN")
-            # note: the batches have class_weight already applied
-            self.model.fit_generator(self.yield_batch(),
-                                     steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
-                                     epochs=epochs,
-                                     validation_data=self.validation_data,
-                                     callbacks=self.callbacks_list)
-            self.is_training = False
-        except KeyboardInterrupt:
-            logger.info("Interrupt training - continue with rest")
-
-        self.checkpoint_model()
-
-
     def clean_mask(self, x):
         """
         Mask recurrent fields such that once a masked value occurs,
@@ -1846,9 +1854,11 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     def yield_batch(self):
-        x_train, y_train, w_train = self.training_data
+        x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
+        train_idx = list(self.train_val_idx[0])
+        np.random.seed(self.shuffle_seed+1)
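+        # mirrors ClassificationProject.yield_batch, drawing only from the training indices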
         while True:
-            shuffled_idx = np.random.permutation(len(x_train))
+            shuffled_idx = np.random.permutation(train_idx)
             for start in range(0, len(shuffled_idx), int(self.batch_size)):
                 x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
                 y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
-- 
GitLab