Skip to content
Snippets Groups Projects
Commit 0e1202f5 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Memory improvements

* stop option for reading a maximum number of events for train/test
* inplace scaling
parent beea1cfe
No related branches found
No related tags found
No related merge requests found
......@@ -133,6 +133,10 @@ class ClassificationProject(object):
:param step_bkg: step size when selecting background training events (e.g. 2 means take every second event)
:param stop_train: stop after this number of events for reading in training events
:param stop_test: stop after this number of events for reading in test events
:param optimizer: name of optimizer class in keras.optimizers
:param optimizer_opts: dictionary of options for the optimizer
......@@ -213,6 +217,8 @@ class ClassificationProject(object):
scaler_type="WeightedRobustScaler",
step_signal=2,
step_bkg=2,
stop_train=None,
stop_test=None,
optimizer="SGD",
optimizer_opts=None,
use_earlystopping=True,
......@@ -267,6 +273,8 @@ class ClassificationProject(object):
self.scaler_type = scaler_type
self.step_signal = step_signal
self.step_bkg = step_bkg
self.stop_train = stop_train
self.stop_test = stop_test
self.optimizer = optimizer
self.use_earlystopping = use_earlystopping
self.use_modelcheckpoint = use_modelcheckpoint
......@@ -372,19 +380,19 @@ class ClassificationProject(object):
self.s_train = tree2array(signal_chain,
branches=self.branches+[self.weight_expr]+self.identifiers,
selection=self.selection,
start=0, step=self.step_signal)
start=0, step=self.step_signal, stop=self.stop_train)
self.b_train = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr]+self.identifiers,
selection=self.selection,
start=0, step=self.step_bkg)
start=0, step=self.step_bkg, stop=self.stop_train)
self.s_test = tree2array(signal_chain,
branches=self.branches+[self.weight_expr],
selection=self.selection,
start=1, step=self.step_signal)
start=1, step=self.step_signal, stop=self.stop_test)
self.b_test = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr],
selection=self.selection,
start=1, step=self.step_bkg)
start=1, step=self.step_bkg, stop=self.stop_test)
self.rename_fields(self.s_train)
self.rename_fields(self.b_train)
......@@ -566,9 +574,12 @@ class ClassificationProject(object):
logger.debug("training data before transformation: {}".format(self.x_train))
logger.debug("minimum values: {}".format([np.min(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
logger.debug("maximum values: {}".format([np.max(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
orig_copy_setting = self.scaler.copy
self.scaler.copy = False
self.x_train = self.scaler.transform(self.x_train)
logger.debug("training data after transformation: {}".format(self.x_train))
self.x_test = self.scaler.transform(self.x_test)
self.scaler.copy = orig_copy_setting
self.data_transformed = True
logger.info("Training and test data transformed")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment