From c77d64ff4cf6c07ca75a0800d551e099ead21221 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Thu, 23 Aug 2018 09:40:01 +0200
Subject: [PATCH] Move mask_value handling from ClassificationProjectRNN into base class

---
 toolkit.py | 68 +++++++++++++++++++++++++++---------------------------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index c3040d8..b31a15a 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -183,6 +183,8 @@ class ClassificationProject(object):
 
     :param loss: loss function name
 
+    :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets); if None (default), no masking is applied
+
     """
 
 
@@ -245,7 +247,8 @@ class ClassificationProject(object):
                         tensorboard_opts=None,
                         random_seed=1234,
                         balance_dataset=False,
-                        loss='binary_crossentropy'):
+                        loss='binary_crossentropy',
+                        mask_value=None):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -316,6 +319,7 @@ class ClassificationProject(object):
         self.random_seed = random_seed
         self.balance_dataset = balance_dataset
         self.loss = loss
+        self.mask_value = mask_value
 
         self.s_train = None
         self.b_train = None
@@ -562,12 +566,31 @@ class ClassificationProject(object):
         return self._scaler
 
 
-    def transform(self, x):
-        return self.scaler.transform(x)
+    def _batch_transform(self, x, fn, batch_size):
+        "Apply fn to x in batches of batch_size rows, temporarily setting mask_value entries to nan"
+        transformed = np.empty(x.shape, dtype=x.dtype)
+        for start in range(0, len(x), batch_size):
+            stop = start+batch_size
+            x_batch = np.array(x[start:stop]) # copy
+            x_batch[x_batch == self.mask_value] = np.nan
+            x_batch = fn(x_batch)
+            x_batch[np.isnan(x_batch)] = self.mask_value
+            transformed[start:stop] = x_batch
+        return transformed
+
+
+    def transform(self, x, batch_size=10000):
+        if self.mask_value is not None:
+            return self._batch_transform(x, self.scaler.transform, batch_size)
+        else:
+            return self.scaler.transform(x)
 
 
-    def inverse_transform(self, x):
-        return self.scaler.inverse_transform(x)
+    def inverse_transform(self, x, batch_size=10000):
+        if self.mask_value is not None:
+            return self._batch_transform(x, self.scaler.inverse_transform, batch_size)
+        else:
+            return self.scaler.inverse_transform(x)
 
 
     @property
@@ -603,6 +626,9 @@ class ClassificationProject(object):
 
     def _transform_data(self):
         if not self.data_transformed:
+            if self.mask_value is not None:
+                self.x_train[self.x_train == self.mask_value] = np.nan
+                self.x_test[self.x_test == self.mask_value] = np.nan
             if logger.level <= logging.DEBUG:
                 logger.debug("training data before transformation: {}".format(self.x_train))
                 logger.debug("minimum values: {}".format([np.min(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
@@ -615,6 +641,9 @@ class ClassificationProject(object):
             logger.debug("training data after transformation: {}".format(self.x_train))
             self.x_test = self.scaler.transform(self.x_test)
             self.scaler.copy = orig_copy_setting
+            if self.mask_value is not None:
+                self.x_train[np.isnan(self.x_train)] = self.mask_value
+                self.x_test[np.isnan(self.x_test)] = self.mask_value
             self.data_transformed = True
             logger.info("Training and test data transformed")
 
@@ -1645,14 +1674,6 @@ class ClassificationProjectRNN(ClassificationProject):
             )
 
 
-    def _transform_data(self):
-        self.x_train[self.x_train == self.mask_value] = np.nan
-        self.x_test[self.x_test == self.mask_value] = np.nan
-        super(ClassificationProjectRNN, self)._transform_data()
-        self.x_train[np.isnan(self.x_train)] = self.mask_value
-        self.x_test[np.isnan(self.x_test)] = self.mask_value
-
-
     @property
     def model(self):
         if self._model is None:
@@ -1784,27 +1805,6 @@ class ClassificationProjectRNN(ClassificationProject):
             eval_score("train")
 
 
-    def _batch_transform(self, x, fn, batch_size):
-        "Transform array in batches, temporarily setting mask_values to nan"
-        transformed = np.empty(x.shape, dtype=x.dtype)
-        for start in range(0, len(x), batch_size):
-            stop = start+batch_size
-            x_batch = np.array(x[start:stop]) # copy
-            x_batch[x_batch == self.mask_value] = np.nan
-            x_batch = fn(x_batch)
-            x_batch[np.isnan(x_batch)] = self.mask_value
-            transformed[start:stop] = x_batch
-        return transformed
-
-
-    def transform(self, x, batch_size=10000):
-        return self._batch_transform(x, self.scaler.transform, batch_size)
-
-
-    def inverse_transform(self, x, batch_size=10000):
-        return self._batch_transform(x, self.scaler.inverse_transform, batch_size)
-
-
     def evaluate(self, x_eval, mode=None):
         logger.debug("Evaluate score for {}".format(x_eval))
         x_eval = self.transform(x_eval)
-- 
GitLab