Skip to content
Snippets Groups Projects
Commit 7d32696b authored by Nikolai Hartmann's avatar Nikolai Hartmann
Browse files

got rid of all global shuffling

parent 649d4d96
No related branches found
No related tags found
No related merge requests found
...@@ -188,6 +188,8 @@ class ClassificationProject(object): ...@@ -188,6 +188,8 @@ class ClassificationProject(object):
:param normalize_weights: normalize the weights to mean 1 :param normalize_weights: normalize the weights to mean 1
:param shuffle: shuffle training data after (and before first) epoch
""" """
...@@ -257,7 +259,9 @@ class ClassificationProject(object): ...@@ -257,7 +259,9 @@ class ClassificationProject(object):
loss='binary_crossentropy', loss='binary_crossentropy',
mask_value=None, mask_value=None,
apply_class_weight=True, apply_class_weight=True,
normalize_weights=True): normalize_weights=True,
shuffle=True,
):
self.name = name self.name = name
self.signal_trees = signal_trees self.signal_trees = signal_trees
...@@ -339,6 +343,7 @@ class ClassificationProject(object): ...@@ -339,6 +343,7 @@ class ClassificationProject(object):
self.mask_value = mask_value self.mask_value = mask_value
self.apply_class_weight = apply_class_weight self.apply_class_weight = apply_class_weight
self.normalize_weights = normalize_weights self.normalize_weights = normalize_weights
self.shuffle = shuffle
self.s_train = None self.s_train = None
self.b_train = None self.b_train = None
...@@ -373,7 +378,6 @@ class ClassificationProject(object): ...@@ -373,7 +378,6 @@ class ClassificationProject(object):
self.data_loaded = False self.data_loaded = False
self.data_transformed = False self.data_transformed = False
self.data_shuffled = False
# track if we are currently training # track if we are currently training
self.is_training = False self.is_training = False
...@@ -475,7 +479,6 @@ class ClassificationProject(object): ...@@ -475,7 +479,6 @@ class ClassificationProject(object):
self._dump_to_hdf5(*self.dataset_names_tree) self._dump_to_hdf5(*self.dataset_names_tree)
self.data_loaded = True self.data_loaded = True
self.data_shuffled = False
def _dump_training_list(self): def _dump_training_list(self):
...@@ -839,23 +842,27 @@ class ClassificationProject(object): ...@@ -839,23 +842,27 @@ class ClassificationProject(object):
@property @property
def validation_data(self): def validation_data(self):
"Validation data" "Validation data for loss evaluation"
idx = self.train_val_idx[1] idx = self.train_val_idx[1]
return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_val_input = self.get_input_list(x_val)
return x_val_input, y_val, w_val
@property @property
def training_data(self): def training_data(self):
"Training data with validation data split off" "Training data with validation data split off"
idx = self.train_val_idx[0] idx = self.train_val_idx[0]
return self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_train_input = self.get_input_list(x_train)
return x_train_input, y_train, w_train
@property @property
def train_val_idx(self): def train_val_idx(self):
if self._train_val_idx is None: if self._train_val_idx is None:
if self.kfold_splits is not None: if self.kfold_splits is not None:
kfold = KFold(self.kfold_splits, shuffle=True, random_state=self.shuffle_seed) kfold = KFold(self.kfold_splits, shuffle=self.shuffle, random_state=self.shuffle_seed)
for i, train_val_idx in enumerate(kfold.split(self.x_train)): for i, train_val_idx in enumerate(kfold.split(self.x_train)):
if i == self.kfold_index: if i == self.kfold_index:
self._train_val_idx = train_val_idx self._train_val_idx = train_val_idx
...@@ -865,7 +872,10 @@ class ClassificationProject(object): ...@@ -865,7 +872,10 @@ class ClassificationProject(object):
else: else:
split_index = int((1-self.validation_split)*len(self.x_train)) split_index = int((1-self.validation_split)*len(self.x_train))
np.random.seed(self.shuffle_seed) np.random.seed(self.shuffle_seed)
shuffled_idx = np.random.permutation(len(self.x_train)) if self.shuffle:
shuffled_idx = np.random.permutation(len(self.x_train))
else:
shuffled_idx = np.arange(len(self.x_train))
self._train_val_idx = (shuffled_idx[:split_index], shuffled_idx[split_index:]) self._train_val_idx = (shuffled_idx[:split_index], shuffled_idx[split_index:])
return self._train_val_idx return self._train_val_idx
...@@ -875,17 +885,30 @@ class ClassificationProject(object): ...@@ -875,17 +885,30 @@ class ClassificationProject(object):
return int(float(len(self.train_val_idx[0]))/float(self.batch_size)) return int(float(len(self.train_val_idx[0]))/float(self.batch_size))
def get_input_list(self, x):
"For the standard Dense models with single input, this does nothing"
return x
def yield_batch(self): def yield_batch(self):
"Batch generator - optionally shuffle the indices after each epoch"
x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
train_idx = list(self.train_val_idx[0]) train_idx = list(self.train_val_idx[0])
np.random.seed(self.shuffle_seed+1) np.random.seed(self.shuffle_seed+1)
logger.info("Generating training batches from {} signal and {} background events"
.format(len(np.where(self.y_train[train_idx]==1)[0]),
len(np.where(self.y_train[train_idx]==0)[0])))
while True: while True:
shuffled_idx = np.random.permutation(train_idx) if self.shuffle:
shuffled_idx = np.random.permutation(train_idx)
else:
shuffled_idx = train_idx
for start in range(0, len(shuffled_idx), int(self.batch_size)): for start in range(0, len(shuffled_idx), int(self.batch_size)):
x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]] x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]] y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]] w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
yield (x_batch, y_batch, w_batch) x_input = self.get_input_list(x_batch)
yield (x_input, y_batch, w_batch)
def yield_single_class_batch(self, class_label): def yield_single_class_batch(self, class_label):
...@@ -897,7 +920,10 @@ class ClassificationProject(object): ...@@ -897,7 +920,10 @@ class ClassificationProject(object):
class_idx = np.where(y_train==class_label)[0] class_idx = np.where(y_train==class_label)[0]
while True: while True:
# shuffle the indices for this class label # shuffle the indices for this class label
shuffled_idx = np.random.permutation(class_idx) if self.shuffle:
shuffled_idx = np.random.permutation(class_idx)
else:
shuffled_idx = class_idx
# yield them batch wise # yield them batch wise
for start in range(0, len(shuffled_idx), int(self.batch_size/2)): for start in range(0, len(shuffled_idx), int(self.batch_size/2)):
yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]], yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]],
...@@ -980,7 +1006,7 @@ class ClassificationProject(object): ...@@ -980,7 +1006,7 @@ class ClassificationProject(object):
def evaluate_train_test(self, do_train=True, do_test=True, mode=None): def evaluate_train_test(self, do_train=True, do_test=True, mode=None):
logger.info("Reloading (and re-transforming) unshuffled training data") logger.info("Reloading (and re-transforming) training data")
self.load(reload=True) self.load(reload=True)
if mode is not None: if mode is not None:
...@@ -1819,7 +1845,10 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1819,7 +1845,10 @@ class ClassificationProjectRNN(ClassificationProject):
def get_input_list(self, x): def get_input_list(self, x):
"Format the input starting from flat ntuple" """
Returns a list of 3-dimensional inputs for each
recurrent layer and a 2-dimensional one for the normal flat inputs.
"""
x_input = [] x_input = []
for field_idx in self.recurrent_field_idx: for field_idx in self.recurrent_field_idx:
x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:]) x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
...@@ -1830,7 +1859,7 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1830,7 +1859,7 @@ class ClassificationProjectRNN(ClassificationProject):
def get_input_flat(self, x): def get_input_flat(self, x):
"Transform input back to flat ntuple" "Transform the multiple inputs back to flat ntuple"
nevent = x[0].shape[0] nevent = x[0].shape[0]
x_flat = np.empty((nevent, len(self.fields)), dtype=np.float) x_flat = np.empty((nevent, len(self.fields)), dtype=np.float)
# recurrent fields # recurrent fields
...@@ -1845,31 +1874,6 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1845,31 +1874,6 @@ class ClassificationProjectRNN(ClassificationProject):
return x_flat return x_flat
def yield_batch(self):
x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
train_idx = list(self.train_val_idx[0])
np.random.seed(self.shuffle_seed+1)
logger.info("Generating training batches from {} signal and {} background events"
.format(len(np.where(self.y_train[train_idx]==1)[0]),
len(np.where(self.y_train[train_idx]==0)[0])))
while True:
shuffled_idx = np.random.permutation(train_idx)
for start in range(0, len(shuffled_idx), int(self.batch_size)):
x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
x_input = self.get_input_list(x_batch)
yield (x_input, y_batch, w_batch)
@property
def validation_data(self):
"class weighted validation data. Attention: Shuffle training data before using this!"
x_val, y_val, w_val = super(ClassificationProjectRNN, self).validation_data
x_val_input = self.get_input_list(x_val)
return x_val_input, y_val, w_val
def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None): def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
logger.info("Reloading (and re-transforming) unshuffled training data") logger.info("Reloading (and re-transforming) unshuffled training data")
self.load(reload=True) self.load(reload=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment