From 7d32696b5a776f894269521f551592de6ae21125 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Mon, 22 Oct 2018 10:38:22 +0200
Subject: [PATCH] got rid of all global shuffling

---
 toolkit.py | 82 ++++++++++++++++++++++++++++--------------------------
 1 file changed, 43 insertions(+), 39 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 6780435..c21f45a 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -188,6 +188,8 @@ class ClassificationProject(object):
 
     :param normalize_weights: normalize the weights to mean 1
 
+    :param shuffle: shuffle the training data before each epoch (also controls shuffling in the train/validation split)
+
     """
 
 
@@ -257,7 +259,9 @@ class ClassificationProject(object):
                         loss='binary_crossentropy',
                         mask_value=None,
                         apply_class_weight=True,
-                        normalize_weights=True):
+                        normalize_weights=True,
+                        shuffle=True,
+    ):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -339,6 +343,7 @@ class ClassificationProject(object):
         self.mask_value = mask_value
         self.apply_class_weight = apply_class_weight
         self.normalize_weights = normalize_weights
+        self.shuffle = shuffle
 
         self.s_train = None
         self.b_train = None
@@ -373,7 +378,6 @@ class ClassificationProject(object):
 
         self.data_loaded = False
         self.data_transformed = False
-        self.data_shuffled = False
 
         # track if we are currently training
         self.is_training = False
@@ -475,7 +479,6 @@ class ClassificationProject(object):
             self._dump_to_hdf5(*self.dataset_names_tree)
 
         self.data_loaded = True
-        self.data_shuffled = False
 
 
     def _dump_training_list(self):
@@ -839,23 +842,27 @@ class ClassificationProject(object):
 
     @property
     def validation_data(self):
-        "Validation data"
+        "Validation data for loss evaluation"
         idx = self.train_val_idx[1]
-        return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
+        x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
+        x_val_input = self.get_input_list(x_val)
+        return x_val_input, y_val, w_val
 
 
     @property
     def training_data(self):
         "Training data with validation data split off"
         idx = self.train_val_idx[0]
-        return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
+        x_train, y_train, w_train =  self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
+        x_train_input = self.get_input_list(x_train)
+        return x_train_input, y_train, w_train
 
 
     @property
     def train_val_idx(self):
         if self._train_val_idx is None:
             if self.kfold_splits is not None:
-                kfold = KFold(self.kfold_splits, shuffle=True, random_state=self.shuffle_seed)
+                kfold = KFold(self.kfold_splits, shuffle=self.shuffle, random_state=self.shuffle_seed)
                 for i, train_val_idx in enumerate(kfold.split(self.x_train)):
                     if i == self.kfold_index:
                         self._train_val_idx = train_val_idx
@@ -865,7 +872,10 @@ class ClassificationProject(object):
             else:
                 split_index = int((1-self.validation_split)*len(self.x_train))
                 np.random.seed(self.shuffle_seed)
-                shuffled_idx = np.random.permutation(len(self.x_train))
+                if self.shuffle:
+                    shuffled_idx = np.random.permutation(len(self.x_train))
+                else:
+                    shuffled_idx = np.arange(len(self.x_train))
                 self._train_val_idx = (shuffled_idx[:split_index], shuffled_idx[split_index:])
         return self._train_val_idx
 
@@ -875,17 +885,30 @@ class ClassificationProject(object):
         return int(float(len(self.train_val_idx[0]))/float(self.batch_size))
 
 
+    def get_input_list(self, x):
+        "Return x unchanged - the standard Dense models with a single input need no reformatting"
+        return x
+
+
     def yield_batch(self):
+        "Batch generator - optionally reshuffle the indices at the start of each epoch"
         x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
         train_idx = list(self.train_val_idx[0])
         np.random.seed(self.shuffle_seed+1)
+        logger.info("Generating training batches from {} signal and {} background events"
+                    .format(len(np.where(self.y_train[train_idx]==1)[0]),
+                            len(np.where(self.y_train[train_idx]==0)[0])))
         while True:
-            shuffled_idx = np.random.permutation(train_idx)
+            if self.shuffle:
+                shuffled_idx = np.random.permutation(train_idx)
+            else:
+                shuffled_idx = train_idx
             for start in range(0, len(shuffled_idx), int(self.batch_size)):
                 x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
                 y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
                 w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
-                yield (x_batch, y_batch, w_batch)
+                x_input = self.get_input_list(x_batch)
+                yield (x_input, y_batch, w_batch)
 
 
     def yield_single_class_batch(self, class_label):
@@ -897,7 +920,10 @@ class ClassificationProject(object):
         class_idx = np.where(y_train==class_label)[0]
         while True:
             # shuffle the indices for this class label
-            shuffled_idx = np.random.permutation(class_idx)
+            if self.shuffle:
+                shuffled_idx = np.random.permutation(class_idx)
+            else:
+                shuffled_idx = class_idx
             # yield them batch wise
             for start in range(0, len(shuffled_idx), int(self.batch_size/2)):
                 yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]],
@@ -980,7 +1006,7 @@ class ClassificationProject(object):
 
 
     def evaluate_train_test(self, do_train=True, do_test=True, mode=None):
-        logger.info("Reloading (and re-transforming) unshuffled training data")
+        logger.info("Reloading (and re-transforming) training data")
         self.load(reload=True)
 
         if mode is not None:
@@ -1819,7 +1845,10 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     def get_input_list(self, x):
-        "Format the input starting from flat ntuple"
+        """
+        Returns a list of 3-dimensional inputs for each
+        recurrent layer and a 2-dimensional one for the normal flat inputs.
+        """
         x_input = []
         for field_idx in self.recurrent_field_idx:
             x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
@@ -1830,7 +1859,7 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     def get_input_flat(self, x):
-        "Transform input back to flat ntuple"
+        "Transform the multiple inputs back to flat ntuple"
         nevent = x[0].shape[0]
         x_flat = np.empty((nevent, len(self.fields)), dtype=np.float)
         # recurrent fields
@@ -1845,31 +1874,6 @@ class ClassificationProjectRNN(ClassificationProject):
         return x_flat
 
 
-    def yield_batch(self):
-        x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
-        train_idx = list(self.train_val_idx[0])
-        np.random.seed(self.shuffle_seed+1)
-        logger.info("Generating training batches from {} signal and {} background events"
-                    .format(len(np.where(self.y_train[train_idx]==1)[0]),
-                            len(np.where(self.y_train[train_idx]==0)[0])))
-        while True:
-            shuffled_idx = np.random.permutation(train_idx)
-            for start in range(0, len(shuffled_idx), int(self.batch_size)):
-                x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
-                y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
-                w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
-                x_input = self.get_input_list(x_batch)
-                yield (x_input, y_batch, w_batch)
-
-
-    @property
-    def validation_data(self):
-        "class weighted validation data. Attention: Shuffle training data before using this!"
-        x_val, y_val, w_val = super(ClassificationProjectRNN, self).validation_data
-        x_val_input = self.get_input_list(x_val)
-        return x_val_input, y_val, w_val
-
-
     def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
         logger.info("Reloading (and re-transforming) unshuffled training data")
         self.load(reload=True)
-- 
GitLab