Skip to content
Snippets Groups Projects
Commit f52e8ad9 authored by Nikolai's avatar Nikolai
Browse files

Start developing the yield_batch function for the RNN project

parent 9994f0ad
No related branches found
No related tags found
No related merge requests found
...@@ -1228,14 +1228,31 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1228,14 +1228,31 @@ class ClassificationProjectRNN(ClassificationProject):
""" """
def __init__(self, def __init__(self,
recurrent_branches=None, recurrent_fields=None,
mask_value=-999, mask_value=-999,
**kwargs): **kwargs):
self.recurrent_branches = recurrent_branches """
if self.recurrent_branches is None: recurrent_fields example:
self.recurrent_branches = [] [["jet1Pt", "jet1Eta", "jet1Phi"],
["jet2Pt", "jet2Eta", "jet2Phi"],
["jet3Pt", "jet3Eta", "jet3Phi"]],
[["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"],
["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]],
"""
self.recurrent_fields = recurrent_fields
if self.recurrent_fields is None:
self.recurrent_fields = []
for i, recurrent_field in enumerate(self.recurrent_fields):
self.recurrent_fields[i] = np.array(recurrent_field)
if self.recurrent_fields[i].dtype == np.object:
raise ValueError(
"Invalid entry for recurrent fields: {} - "
"please ensure that the length for all elements in the list is equal"
.format(recurrent_field)
)
self.mask_value = mask_value self.mask_value = mask_value
super(ClassificationProjectRNN, self).__init__() super(ClassificationProjectRNN, self).__init__()
self.flat_fields = [field for field in self.fields if not field in self.recurrent_fields]
@property @property
...@@ -1244,12 +1261,20 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1244,12 +1261,20 @@ class ClassificationProjectRNN(ClassificationProject):
def yield_batch(self): def yield_batch(self):
x_train, y_train, w_train = self.training_data
while True: while True:
permutation = np.random.permutation shuffled_idx = np.random.permutation(len(x_train))
x_train, y_train, w_train = self.training_data for start in range(0, len(shuffled_idx), int(self.batch_size)):
n_training = len(x_train) x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
for batch_start in range(0, n_training, self.batch_size): x_flat = x_batch[:,self.flat_fields]
pass x_input = []
x_input.append(x_flat)
for recurrent_field in self.recurrent_fields:
x_recurrent = x_batch[:,recurrent_field.reshape(-1)].reshape(-1, *recurrent_field.shape)
x_input.append(x_recurrent)
yield (x_input,
y_train[shuffled_idx[start:start+int(self.batch_size)]],
w_train[shuffled_idx[start:start+int(self.batch_size)]]*self.balanced_class_weight[class_label])
# # shuffle the entries for this class label # # shuffle the entries for this class label
# rn_state = np.random.get_state() # rn_state = np.random.get_state()
# x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label]) # x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment