From 0e1202f542e1107f8bd2bf98e63a2f4a9d8f7a6b Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Fri, 10 Aug 2018 15:42:39 +0200
Subject: [PATCH] Memory improvements

* add stop_train/stop_test options to limit how many entries are read for train/test
* scale training and test data in place (scaler.copy temporarily set to False)
---
 toolkit.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index 875cf44..19fa5ec 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -133,6 +133,10 @@ class ClassificationProject(object):
 
     :param step_bkg: step size when selecting background training events (e.g. 2 means take every second event)
 
+    :param stop_train: entry index at which to stop reading training events (passed as ``stop`` to tree2array together with start/step; None means no limit)
+
+    :param stop_test: entry index at which to stop reading test events (passed as ``stop`` to tree2array together with start/step; None means no limit)
+
     :param optimizer: name of optimizer class in keras.optimizers
 
     :param optimizer_opts: dictionary of options for the optimizer
@@ -213,6 +217,8 @@ class ClassificationProject(object):
                         scaler_type="WeightedRobustScaler",
                         step_signal=2,
                         step_bkg=2,
+                        stop_train=None,
+                        stop_test=None,
                         optimizer="SGD",
                         optimizer_opts=None,
                         use_earlystopping=True,
@@ -267,6 +273,8 @@ class ClassificationProject(object):
         self.scaler_type = scaler_type
         self.step_signal = step_signal
         self.step_bkg = step_bkg
+        self.stop_train = stop_train
+        self.stop_test = stop_test
         self.optimizer = optimizer
         self.use_earlystopping = use_earlystopping
         self.use_modelcheckpoint = use_modelcheckpoint
@@ -372,19 +380,19 @@ class ClassificationProject(object):
             self.s_train = tree2array(signal_chain,
                                       branches=self.branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
-                                      start=0, step=self.step_signal)
+                                      start=0, step=self.step_signal, stop=self.stop_train)
             self.b_train = tree2array(bkg_chain,
                                       branches=self.branches+[self.weight_expr]+self.identifiers,
                                       selection=self.selection,
-                                      start=0, step=self.step_bkg)
+                                      start=0, step=self.step_bkg, stop=self.stop_train)
             self.s_test = tree2array(signal_chain,
                                      branches=self.branches+[self.weight_expr],
                                      selection=self.selection,
-                                     start=1, step=self.step_signal)
+                                     start=1, step=self.step_signal, stop=self.stop_test)
             self.b_test = tree2array(bkg_chain,
                                      branches=self.branches+[self.weight_expr],
                                      selection=self.selection,
-                                     start=1, step=self.step_bkg)
+                                     start=1, step=self.step_bkg, stop=self.stop_test)
 
             self.rename_fields(self.s_train)
             self.rename_fields(self.b_train)
@@ -566,9 +574,12 @@ class ClassificationProject(object):
             logger.debug("training data before transformation: {}".format(self.x_train))
             logger.debug("minimum values: {}".format([np.min(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
             logger.debug("maximum values: {}".format([np.max(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
+            orig_copy_setting = self.scaler.copy
+            self.scaler.copy = False
             self.x_train = self.scaler.transform(self.x_train)
             logger.debug("training data after transformation: {}".format(self.x_train))
             self.x_test = self.scaler.transform(self.x_test)
+            self.scaler.copy = orig_copy_setting
             self.data_transformed = True
             logger.info("Training and test data transformed")
 
-- 
GitLab