From 7d32696b5a776f894269521f551592de6ae21125 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <nikolai.hartmann@gmx.de> Date: Mon, 22 Oct 2018 10:38:22 +0200 Subject: [PATCH] got rid of all global shuffling --- toolkit.py | 82 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/toolkit.py b/toolkit.py index 6780435..c21f45a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -188,6 +188,8 @@ class ClassificationProject(object): :param normalize_weights: normalize the weights to mean 1 + :param shuffle: shuffle training data after (and before first) epoch + """ @@ -257,7 +259,9 @@ class ClassificationProject(object): loss='binary_crossentropy', mask_value=None, apply_class_weight=True, - normalize_weights=True): + normalize_weights=True, + shuffle=True, + ): self.name = name self.signal_trees = signal_trees @@ -339,6 +343,7 @@ class ClassificationProject(object): self.mask_value = mask_value self.apply_class_weight = apply_class_weight self.normalize_weights = normalize_weights + self.shuffle = shuffle self.s_train = None self.b_train = None @@ -373,7 +378,6 @@ class ClassificationProject(object): self.data_loaded = False self.data_transformed = False - self.data_shuffled = False # track if we are currently training self.is_training = False @@ -475,7 +479,6 @@ class ClassificationProject(object): self._dump_to_hdf5(*self.dataset_names_tree) self.data_loaded = True - self.data_shuffled = False def _dump_training_list(self): @@ -839,23 +842,27 @@ class ClassificationProject(object): @property def validation_data(self): - "Validation data" + "Validation data for loss evaluation" idx = self.train_val_idx[1] - return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] + x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] + x_val_input = self.get_input_list(x_val) + return x_val_input, y_val, w_val @property def training_data(self): "Training data with validation data split off" idx = self.train_val_idx[0] - return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] + x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] + x_train_input = self.get_input_list(x_train) + return x_train_input, y_train, w_train @property def train_val_idx(self): if self._train_val_idx is None: if self.kfold_splits is not None: - kfold = KFold(self.kfold_splits, shuffle=True, random_state=self.shuffle_seed) + kfold = KFold(self.kfold_splits, shuffle=self.shuffle, random_state=self.shuffle_seed) for i, train_val_idx in enumerate(kfold.split(self.x_train)): if i == self.kfold_index: self._train_val_idx = train_val_idx @@ -865,7 +872,10 @@ class ClassificationProject(object): else: split_index = int((1-self.validation_split)*len(self.x_train)) np.random.seed(self.shuffle_seed) - shuffled_idx = np.random.permutation(len(self.x_train)) + if self.shuffle: + shuffled_idx = np.random.permutation(len(self.x_train)) + else: + shuffled_idx = np.arange(len(self.x_train)) self._train_val_idx = (shuffled_idx[:split_index], shuffled_idx[split_index:]) return self._train_val_idx @@ -875,17 +885,30 @@ class ClassificationProject(object): return int(float(len(self.train_val_idx[0]))/float(self.batch_size)) + def get_input_list(self, x): + "For the standard Dense models with single input, this does nothing" + return x + + def yield_batch(self): + "Batch generator - optionally shuffle the indices after each epoch" x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot train_idx = list(self.train_val_idx[0]) np.random.seed(self.shuffle_seed+1) + logger.info("Generating training batches from {} signal and {} background events" + .format(len(np.where(self.y_train[train_idx]==1)[0]), + len(np.where(self.y_train[train_idx]==0)[0]))) while True: - shuffled_idx = np.random.permutation(train_idx) + if self.shuffle: + shuffled_idx = np.random.permutation(train_idx) + else: + shuffled_idx = train_idx for start in range(0, len(shuffled_idx), int(self.batch_size)): x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]] y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]] w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]] - yield (x_batch, y_batch, w_batch) + x_input = self.get_input_list(x_batch) + yield (x_input, y_batch, w_batch) def yield_single_class_batch(self, class_label): @@ -897,7 +920,10 @@ class ClassificationProject(object): class_idx = np.where(y_train==class_label)[0] while True: # shuffle the indices for this class label - shuffled_idx = np.random.permutation(class_idx) + if self.shuffle: + shuffled_idx = np.random.permutation(class_idx) + else: + shuffled_idx = class_idx # yield them batch wise for start in range(0, len(shuffled_idx), int(self.batch_size/2)): yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]], @@ -980,7 +1006,7 @@ class ClassificationProject(object): def evaluate_train_test(self, do_train=True, do_test=True, mode=None): - logger.info("Reloading (and re-transforming) unshuffled training data") + logger.info("Reloading (and re-transforming) training data") self.load(reload=True) if mode is not None: @@ -1819,7 +1845,10 @@ class ClassificationProjectRNN(ClassificationProject): def get_input_list(self, x): - "Format the input starting from flat ntuple" + """ + Returns a list of 3-dimensional inputs for each + recurrent layer and a 2-dimensional one for the normal flat inputs. + """ x_input = [] for field_idx in self.recurrent_field_idx: x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:]) @@ -1830,7 +1859,7 @@ class ClassificationProjectRNN(ClassificationProject): def get_input_flat(self, x): - "Transform input back to flat ntuple" + "Transform the multiple inputs back to flat ntuple" nevent = x[0].shape[0] x_flat = np.empty((nevent, len(self.fields)), dtype=np.float) # recurrent fields @@ -1845,31 +1874,6 @@ class ClassificationProjectRNN(ClassificationProject): return x_flat - def yield_batch(self): - x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot - train_idx = list(self.train_val_idx[0]) - np.random.seed(self.shuffle_seed+1) - logger.info("Generating training batches from {} signal and {} background events" - .format(len(np.where(self.y_train[train_idx]==1)[0]), - len(np.where(self.y_train[train_idx]==0)[0]))) - while True: - shuffled_idx = np.random.permutation(train_idx) - for start in range(0, len(shuffled_idx), int(self.batch_size)): - x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]] - y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]] - w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]] - x_input = self.get_input_list(x_batch) - yield (x_input, y_batch, w_batch) - - - @property - def validation_data(self): - "class weighted validation data. Attention: Shuffle training data before using this!" - x_val, y_val, w_val = super(ClassificationProjectRNN, self).validation_data - x_val_input = self.get_input_list(x_val) - return x_val_input, y_val, w_val - - def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None): logger.info("Reloading (and re-transforming) unshuffled training data") self.load(reload=True) -- GitLab