From f52e8ad96613b8f4031b1a6ffa4278d09c1bef5c Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 14 Aug 2018 11:50:44 +0200
Subject: [PATCH] starting to develop yield_batch function for RNN

---
 toolkit.py | 43 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index befd03e..c6ff50e 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1228,14 +1228,31 @@ class ClassificationProjectRNN(ClassificationProject):
     """
 
     def __init__(self,
-                 recurrent_branches=None,
+                 recurrent_fields=None,
                  mask_value=-999,
                  **kwargs):
-        self.recurrent_branches = recurrent_branches
-        if self.recurrent_branches is None:
-            self.recurrent_branches = []
+        """
+        recurrent_fields example:
+        [["jet1Pt", "jet1Eta", "jet1Phi"],
+         ["jet2Pt", "jet2Eta", "jet2Phi"],
+         ["jet3Pt", "jet3Eta", "jet3Phi"]],
+        [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"],
+         ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]],
+        """
+        self.recurrent_fields = recurrent_fields
+        if self.recurrent_fields is None:
+            self.recurrent_fields = []
+        for i, recurrent_field in enumerate(self.recurrent_fields):
+            self.recurrent_fields[i] = np.array(recurrent_field)
+            if self.recurrent_fields[i].dtype == np.object:
+                raise ValueError(
+                    "Invalid entry for recurrent fields: {} - "
+                    "please ensure that the length for all elements in the list is equal"
+                    .format(recurrent_field)
+                )
         self.mask_value = mask_value
         super(ClassificationProjectRNN, self).__init__()
+        self.flat_fields = [field for field in self.fields if not field in self.recurrent_fields]
 
 
     @property
@@ -1244,12 +1261,20 @@ class ClassificationProjectRNN(ClassificationProject):
 
 
     def yield_batch(self):
+        x_train, y_train, w_train = self.training_data
         while True:
-            permutation = np.random.permutation
-            x_train, y_train, w_train = self.training_data
-            n_training = len(x_train)
-            for batch_start in range(0, n_training, self.batch_size):
-                pass
+            shuffled_idx = np.random.permutation(len(x_train))
+            for start in range(0, len(shuffled_idx), int(self.batch_size)):
+                x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
+                x_flat = x_batch[:,self.flat_fields]
+                x_input = []
+                x_input.append(x_flat)
+                for recurrent_field in self.recurrent_fields:
+                    x_recurrent = x_batch[:,recurrent_field.reshape(-1)].reshape(-1, *recurrent_field.shape)
+                    x_input.append(x_recurrent)
+                yield (x_input,
+                       y_train[shuffled_idx[start:start+int(self.batch_size)]],
+                       w_train[shuffled_idx[start:start+int(self.batch_size)]]*self.balanced_class_weight[class_label])
             # # shuffle the entries for this class label
             # rn_state = np.random.get_state()
             # x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
-- 
GitLab