From 7c1241ff46ab1fa9179cdae515669af39d59d6c1 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 14 Aug 2018 16:32:05 +0200
Subject: [PATCH] evaluate RNN working

---
 toolkit.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/toolkit.py b/toolkit.py
index f1e9e4f..57ef3f3 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1498,6 +1498,35 @@ class ClassificationProjectRNN(ClassificationProject):
         return x_val_input, y_val, w_val
 
 
+    def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000):
+        logger.info("Reloading (and re-transforming) unshuffled training data")
+        self.load(reload=True)
+
+        def eval_score(data_name):
+            logger.info("Create/Update scores for {} sample".format(data_name))
+            n_events = len(getattr(self, "x_"+data_name))
+            setattr(self, "scores_"+data_name, np.empty(n_events))
+            for start in range(0, n_events, batch_size):
+                stop = start+batch_size
+                getattr(self, "scores_"+data_name)[start:stop] = self.model.predict(self.get_input_list(getattr(self, "x_"+data_name)[start:stop])).reshape(-1)
+            self._dump_to_hdf5("scores_"+data_name)
+
+        if do_test:
+            eval_score("test")
+        if do_train:
+            eval_score("train")
+
+
+    def evaluate(self, x_eval):
+        logger.debug("Evaluate score for {}".format(x_eval))
+        x_eval = np.array(x_eval) # copy
+        x_eval[x_eval==self.mask_value] = np.nan
+        x_eval = self.scaler.transform(x_eval)
+        x_eval[np.isnan(x_eval)] = self.mask_value
+        logger.debug("Evaluate for transformed array: {}".format(x_eval))
+        return self.model.predict(self.get_input_list(x_eval))
+
+
 if __name__ == "__main__":
 
     logging.basicConfig()
-- 
GitLab