From 9994f0ad22c64d6214830f83dff6b061b56c6ed2 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 14 Aug 2018 09:37:17 +0200
Subject: [PATCH] starting rnn wrapper
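
The idea, as a rough sketch (the branch names below are placeholders and
all other keyword arguments are meant to be forwarded unchanged to the
existing ClassificationProject constructor):

    project = ClassificationProjectRNN(
        recurrent_branches=["jet_pt", "jet_eta", "jet_phi"],
        mask_value=-999,
        # ... plus the usual ClassificationProject options
    )
    # yield_batch() loops over the training data indefinitely,
    # shuffling it and yielding (x, y, weight) batches
    batch_generator = project.yield_batch()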

---
 toolkit.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/toolkit.py b/toolkit.py
index 37bb1b2..befd03e 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1221,6 +1221,48 @@ class ClassificationProjectDataFrame(ClassificationProject):
         pass
 
 
+class ClassificationProjectRNN(ClassificationProject):
+
+    """
+    A little wrapper to use recurrent units for things like jet collections
+    """
+
+    def __init__(self,
+                 recurrent_branches=None,
+                 mask_value=-999,
+                 **kwargs):
+        self.recurrent_branches = recurrent_branches
+        if self.recurrent_branches is None:
+            self.recurrent_branches = []
+        self.mask_value = mask_value  # value used to mask padded entries
+        super(ClassificationProjectRNN, self).__init__(**kwargs)
+
+
+    @property
+    def model(self):
+        raise NotImplementedError("RNN model is not implemented yet")
+
+
+    def yield_batch(self):
+        while True:
+            x_train, y_train, w_train = self.training_data
+            n_training = len(x_train)
+            # shuffle inputs, labels and weights with the same permutation
+            permutation = np.random.permutation(n_training)
+            x_train = x_train[permutation]
+            y_train = y_train[permutation]
+            w_train = w_train[permutation]
+            # yield the shuffled training data batch-wise
+            for batch_start in range(0, n_training, self.batch_size):
+                batch_end = batch_start + self.batch_size
+                yield (x_train[batch_start:batch_end],
+                       y_train[batch_start:batch_end],
+                       w_train[batch_start:batch_end])
+            # TODO: optionally yield class-balanced batches instead, with half a
+            # batch per class label, weighted with
+            # self.balanced_class_weight[class_label]
+
+
 if __name__ == "__main__":
 
     logging.basicConfig()
-- 
GitLab