diff --git a/toolkit.py b/toolkit.py index befd03e9655e5a0c758022a8f0c399503c962061..c6ff50e16b787b1795ca189e00aa9c08928e5ed6 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1228,14 +1228,31 @@ class ClassificationProjectRNN(ClassificationProject): """ def __init__(self, - recurrent_branches=None, + recurrent_fields=None, mask_value=-999, **kwargs): - self.recurrent_branches = recurrent_branches - if self.recurrent_branches is None: - self.recurrent_branches = [] + """ + recurrent_fields example: + [["jet1Pt", "jet1Eta", "jet1Phi"], + ["jet2Pt", "jet2Eta", "jet2Phi"], + ["jet3Pt", "jet3Eta", "jet3Phi"]], + [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"], + ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]], + """ + self.recurrent_fields = recurrent_fields + if self.recurrent_fields is None: + self.recurrent_fields = [] + for i, recurrent_field in enumerate(self.recurrent_fields): + self.recurrent_fields[i] = np.array(recurrent_field) + if self.recurrent_fields[i].dtype == np.object: + raise ValueError( + "Invalid entry for recurrent fields: {} - " + "please ensure that the length for all elements in the list is equal" + .format(recurrent_field) + ) self.mask_value = mask_value super(ClassificationProjectRNN, self).__init__() + self.flat_fields = [field for field in self.fields if not field in self.recurrent_fields] @property @@ -1244,12 +1261,20 @@ class ClassificationProjectRNN(ClassificationProject): def yield_batch(self): + x_train, y_train, w_train = self.training_data while True: - permutation = np.random.permutation - x_train, y_train, w_train = self.training_data - n_training = len(x_train) - for batch_start in range(0, n_training, self.batch_size): - pass + shuffled_idx = np.random.permutation(len(x_train)) + for start in range(0, len(shuffled_idx), int(self.batch_size)): + x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]] + x_flat = x_batch[:,self.flat_fields] + x_input = [] + x_input.append(x_flat) + for recurrent_field in self.recurrent_fields: + x_recurrent = x_batch[:,recurrent_field.reshape(-1)].reshape(-1, *recurrent_field.shape) + x_input.append(x_recurrent) + yield (x_input, + y_train[shuffled_idx[start:start+int(self.batch_size)]], + w_train[shuffled_idx[start:start+int(self.batch_size)]]*self.balanced_class_weight[class_label]) # # shuffle the entries for this class label # rn_state = np.random.get_state() # x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])