diff --git a/toolkit.py b/toolkit.py
index 37bb1b20d406b2d7f953810ef0e1c6f725499062..befd03e9655e5a0c758022a8f0c399503c962061 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1221,6 +1221,48 @@ class ClassificationProjectDataFrame(ClassificationProject):
     pass
 
 
+class ClassificationProjectRNN(ClassificationProject):
+
+    """
+    A little wrapper to use recurrent units for things like jet collections
+    """
+
+    def __init__(self,
+                 recurrent_branches=None,
+                 mask_value=-999,
+                 **kwargs):
+        self.recurrent_branches = recurrent_branches
+        if self.recurrent_branches is None:
+            self.recurrent_branches = []
+        self.mask_value = mask_value
+        super(ClassificationProjectRNN, self).__init__(**kwargs)
+
+
+    @property
+    def model(self):
+        pass
+
+
+    def yield_batch(self):
+        while True:
+            permutation = np.random.permutation
+            x_train, y_train, w_train = self.training_data
+            n_training = len(x_train)
+            for batch_start in range(0, n_training, self.batch_size):
+                pass
+            # # shuffle the entries for this class label
+            # rn_state = np.random.get_state()
+            # x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
+            # np.random.set_state(rn_state)
+            # w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
+            # # yield them batch wise
+            # for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
+            #     yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
+            #            y_train[y_train==class_label][start:start+int(self.batch_size/2)],
+            #            w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label])
+            # restart
+
+
 if __name__ == "__main__":
 
     logging.basicConfig()
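
The commented-out body of yield_batch sketches a per-class, weight-balanced batch generator. Below is a minimal standalone sketch of that idea, assuming NumPy arrays for x_train, y_train, w_train and a balanced_class_weight lookup indexed by class label, as the surrounding project suggests; the name yield_balanced_batches and its signature are hypothetical and not part of the diff.

import numpy as np

def yield_balanced_batches(x_train, y_train, w_train, batch_size, balanced_class_weight):
    # Endlessly yield (x, y, w) half-batches, one class label at a time,
    # with per-class shuffling and weights rescaled by the balanced class weight.
    half_batch = int(batch_size / 2)
    while True:
        for class_label in np.unique(y_train):
            sel = (y_train == class_label)
            x_cls, y_cls, w_cls = x_train[sel], y_train[sel], w_train[sel]
            # shuffle x and w with the same random state so entries stay aligned
            rn_state = np.random.get_state()
            x_cls = np.random.permutation(x_cls)
            np.random.set_state(rn_state)
            w_cls = np.random.permutation(w_cls)
            # yield batch-wise, as in the commented-out sketch
            for start in range(0, len(x_cls), half_batch):
                yield (x_cls[start:start + half_batch],
                       y_cls[start:start + half_batch],
                       w_cls[start:start + half_batch] * balanced_class_weight[class_label])

Inside the class itself, x_train, y_train and w_train would come from self.training_data, and batch_size and balanced_class_weight from the project's configuration, as the commented-out lines indicate.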