diff --git a/toolkit.py b/toolkit.py
index f108d1da7bedf4707e3ca16cefcbe5d75bf39023..51efd3f04d6f0c0078694f168fa4fcffe13cd0ad 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -33,6 +33,7 @@ from sklearn.preprocessing import StandardScaler, RobustScaler
 from sklearn.externals import joblib
 from sklearn.metrics import roc_curve, auc
 from sklearn.utils.extmath import stable_cumsum
+from sklearn.model_selection import KFold
 from keras.models import Sequential, Model, model_from_json
 from keras.layers import Dense, Dropout, Input, Masking, GRU, LSTM, concatenate, SimpleRNN
 from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
@@ -125,6 +126,10 @@ class ClassificationProject(object):
 
     :param validation_split: split off this fraction of training events for loss evaluation
 
+    :param kfold_splits: if given, split into this number of subsets to perform KFold cross validation
+
+    :param kfold_index: index of the subset to leave out for kfold validation
+
     :param activation_function: activation function in the hidden layers
 
     :param activation_function_output: activation function in the output layer
@@ -224,6 +229,8 @@ class ClassificationProject(object):
                  dropout_input=None,
                  batch_size=128,
                  validation_split=0.33,
+                 kfold_splits=None,
+                 kfold_index=0,
                  activation_function='relu',
                  activation_function_output='sigmoid',
                  scaler_type="WeightedRobustScaler",
@@ -284,6 +291,8 @@ class ClassificationProject(object):
         self.dropout_input = dropout_input
         self.batch_size = batch_size
         self.validation_split = validation_split
+        self.kfold_splits = kfold_splits
+        self.kfold_index = kfold_index
         self.activation_function = activation_function
         self.activation_function_output = activation_function_output
         self.scaler_type = scaler_type
@@ -347,6 +356,7 @@ class ClassificationProject(object):
         self._model = None
         self._history = None
         self._callbacks_list = []
+        self._train_val_idx = None
 
         # track the number of epochs this model has been trained
         self.total_epochs = 0
@@ -842,20 +852,53 @@ class ClassificationProject(object):
 
     @property
     def validation_data(self):
-        "Validation data. Attention: Shuffle training data before using this!"
-        if not self.data_shuffled:
-            raise ValueError("Training data isn't shuffled, can't split of validation data")
-        split_index = int((1-self.validation_split)*len(self.x_train))
-        return self.x_train[split_index:], self.y_train[split_index:], self.w_train_tot[split_index:]
+        "Validation data"
+        idx = self.train_val_idx[1]
+        return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
 
 
     @property
     def training_data(self):
-        "Training data with validation data split off. Attention: Shuffle training data before using this!"
-        if not self.data_shuffled:
-            raise ValueError("Training data isn't shuffled, can't split of validation data")
-        split_index = int((1-self.validation_split)*len(self.x_train))
-        return self.x_train[:split_index], self.y_train[:split_index], self.w_train_tot[:split_index]
+        "Training data with validation data split off"
+        idx = self.train_val_idx[0]
+        return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
+
+
+    @property
+    def train_val_idx(self):
+        if self._train_val_idx is None:
+            if self.kfold_splits is not None:
+                kfold = KFold(self.kfold_splits, shuffle=True, random_state=self.shuffle_seed)
+                for i, train_val_idx in enumerate(kfold.split(self.x_train)):
+                    if i == self.kfold_index:
+                        self._train_val_idx = train_val_idx
+                        break
+                else:
+                    raise IndexError("Index {} out of range for kfold (requested {} splits)".format(self.kfold_index, self.kfold_splits))
+            else:
+                split_index = int((1-self.validation_split)*len(self.x_train))
+                np.random.seed(self.shuffle_seed)
+                shuffled_idx = np.random.permutation(len(self.x_train))
+                self._train_val_idx = (shuffled_idx[:split_index], shuffled_idx[split_index:])
+        return self._train_val_idx
+
+
+    @property
+    def steps_per_epoch(self):
+        return int(float(len(self.train_val_idx[0]))/float(self.batch_size))
+
+
+    def yield_batch(self):
+        x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
+        train_idx = list(self.train_val_idx[0])
+        np.random.seed(self.shuffle_seed+1)
+        while True:
+            shuffled_idx = np.random.permutation(train_idx)
+            for start in range(0, len(shuffled_idx), int(self.batch_size)):
+                x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
+                y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
+                w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
+                yield (x_batch, y_batch, w_batch)
 
 
     def yield_single_class_batch(self, class_label):
@@ -894,8 +937,6 @@ class ClassificationProject(object):
 
         self.load()
 
-        self.shuffle_training_data()
-
         for branch_index, branch in enumerate(self.fields):
             self.plot_input(branch_index)
 
@@ -905,18 +946,11 @@ class ClassificationProject(object):
         if not self.balance_dataset:
             try:
                 self.is_training = True
-                np.random.seed(self.shuffle_seed+1) # since we use keras shuffling here
-                self.model.fit(self.x_train,
-                               # the reshape might be unnescessary here
-                               self.y_train.reshape(-1, 1),
-                               epochs=epochs,
-                               validation_split=self.validation_split,
-                               # we have to multiply by class weight since keras ignores class weight if sample weight is given
-                               # see https://github.com/keras-team/keras/issues/497
-                               sample_weight=self.w_train_tot,
-                               shuffle=True,
-                               batch_size=self.batch_size,
-                               callbacks=self.callbacks_list)
+                self.model.fit_generator(self.yield_batch(),
+                                         steps_per_epoch=self.steps_per_epoch,
+                                         epochs=epochs,
+                                         validation_data=self.validation_data,
+                                         callbacks=self.callbacks_list)
                 self.is_training = False
             except KeyboardInterrupt:
                 logger.info("Interrupt training - continue with rest")
@@ -1761,32 +1795,6 @@ class ClassificationProjectRNN(ClassificationProject):
         return self._model
 
 
-    def train(self, epochs=10):
-        self.load()
-
-        self.shuffle_training_data()
-
-        for branch_index, branch in enumerate(self.fields):
-            self.plot_input(branch_index)
-
-        self.total_epochs = self._read_info("epochs", 0)
-
-        try:
-            self.is_training = True
-            logger.info("Training on batches for RNN")
-            # note: the batches have class_weight already applied
-            self.model.fit_generator(self.yield_batch(),
-                                     steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
-                                     epochs=epochs,
-                                     validation_data=self.validation_data,
-                                     callbacks=self.callbacks_list)
-            self.is_training = False
-        except KeyboardInterrupt:
-            logger.info("Interrupt training - continue with rest")
-
-        self.checkpoint_model()
-
-
     def clean_mask(self, x):
         """
         Mask recurrent fields such that once a masked value occurs,
@@ -1846,9 +1854,11 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     def yield_batch(self):
-        x_train, y_train, w_train = self.training_data
+        x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
+        train_idx = list(self.train_val_idx[0])
+        np.random.seed(self.shuffle_seed+1)
         while True:
-            shuffled_idx = np.random.permutation(len(x_train))
+            shuffled_idx = np.random.permutation(train_idx)
             for start in range(0, len(shuffled_idx), int(self.batch_size)):
                 x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
                 y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
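
The fold selection in the new train_val_idx property follows the standard scikit-learn KFold pattern: split() yields one (train_idx, val_idx) pair per fold, and enumerate picks out the fold requested via kfold_index. A minimal standalone sketch of that pattern, using toy data and hypothetical n_splits/fold_index/seed values:

    import numpy as np
    from sklearn.model_selection import KFold

    x = np.arange(100).reshape(-1, 1)      # toy stand-in for x_train
    n_splits, fold_index, seed = 5, 2, 42  # hypothetical settings

    kfold = KFold(n_splits, shuffle=True, random_state=seed)
    # split() yields (train_idx, val_idx) index arrays, one pair per fold
    for i, (train_idx, val_idx) in enumerate(kfold.split(x)):
        if i == fold_index:
            break
    else:
        raise IndexError("fold index out of range")

    print(len(train_idx), len(val_idx))    # 80 20

With a fixed random_state, the same kfold_index always reproduces the same partition, so separate trainings can be run on complementary folds.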
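The switch to fit_generator relies on the usual Keras contract: the generator loops forever, yielding (inputs, targets, sample_weights) tuples, while steps_per_epoch tells Keras how many batches to draw per epoch (the patch uses int(len(train_idx)/batch_size), i.e. the number of full batches in one pass over the training indices). A minimal sketch of the same reshuffle-per-pass batching, with toy arrays and hypothetical names:

    import numpy as np

    def batches(x, y, w, train_idx, batch_size, seed=0):
        rng = np.random.RandomState(seed)
        while True:                                # Keras generators never terminate
            shuffled = rng.permutation(train_idx)  # fresh epoch order on every pass
            for start in range(0, len(shuffled), batch_size):
                sel = shuffled[start:start + batch_size]
                yield x[sel], y[sel], w[sel]       # (inputs, targets, sample_weights)

    x = np.arange(10).reshape(-1, 1)
    y = (np.arange(10) % 2).astype(float)
    w = np.ones(10)
    gen = batches(x, y, w, train_idx=np.arange(8), batch_size=3)
    steps_per_epoch = 8 // 3                       # 2 full batches per pass
    print(next(gen)[0].shape)                      # (3, 1)

Yielding the weights as the third tuple element is what keeps the per-event weights (with class weights already multiplied in) applied, since fit_generator, unlike fit, takes no sample_weight argument.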