Skip to content
Snippets Groups Projects
Commit 9994f0ad authored by Nikolai's avatar Nikolai
Browse files

starting rnn wrapper

parent 6409522c
No related branches found
No related tags found
No related merge requests found
......@@ -1221,6 +1221,48 @@ class ClassificationProjectDataFrame(ClassificationProject):
pass
class ClassificationProjectRNN(ClassificationProject):
"""
A little wrapper to use recurrent units for things like jet collections
"""
def __init__(self,
recurrent_branches=None,
mask_value=-999,
**kwargs):
self.recurrent_branches = recurrent_branches
if self.recurrent_branches is None:
self.recurrent_branches = []
self.mask_value = mask_value
super(ClassificationProjectRNN, self).__init__()
@property
def model():
pass
def yield_batch(self):
while True:
permutation = np.random.permutation
x_train, y_train, w_train = self.training_data
n_training = len(x_train)
for batch_start in range(0, n_training, self.batch_size):
pass
# # shuffle the entries for this class label
# rn_state = np.random.get_state()
# x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
# np.random.set_state(rn_state)
# w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
# # yield them batch wise
# for start in range(0, len(x_train[y_train==class_label]), int(self.batch_size/2)):
# yield (x_train[y_train==class_label][start:start+int(self.batch_size/2)],
# y_train[y_train==class_label][start:start+int(self.batch_size/2)],
# w_train[y_train==class_label][start:start+int(self.batch_size/2)]*self.balanced_class_weight[class_label])
# restart
if __name__ == "__main__":
logging.basicConfig()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment