Commit 89ff9647 authored by Nikolai.Hartmann

Merge branch 'dev-mask'

parents b041276b ef0d8a48
@@ -184,6 +184,8 @@ class ClassificationProject(object):
     :param loss: loss function name

+    :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
     """
@@ -246,7 +248,8 @@ class ClassificationProject(object):
                  tensorboard_opts=None,
                  random_seed=1234,
                  balance_dataset=False,
-                 loss='binary_crossentropy'):
+                 loss='binary_crossentropy',
+                 mask_value=None):

         self.name = name
         self.signal_trees = signal_trees
@@ -317,6 +320,7 @@ class ClassificationProject(object):
         self.random_seed = random_seed
         self.balance_dataset = balance_dataset
         self.loss = loss
+        self.mask_value = mask_value

         self.s_train = None
         self.b_train = None
@@ -565,12 +569,31 @@ class ClassificationProject(object):
         return self._scaler

-    def transform(self, x):
-        return self.scaler.transform(x)
+    def _batch_transform(self, x, fn, batch_size):
+        "Transform array in batches, temporarily setting mask_values to nan"
+        transformed = np.empty(x.shape, dtype=x.dtype)
+        for start in range(0, len(x), batch_size):
+            stop = start+batch_size
+            x_batch = np.array(x[start:stop]) # copy
+            x_batch[x_batch == self.mask_value] = np.nan
+            x_batch = fn(x_batch)
+            x_batch[np.isnan(x_batch)] = self.mask_value
+            transformed[start:stop] = x_batch
+        return transformed
+
+    def transform(self, x, batch_size=10000):
+        if self.mask_value is not None:
+            return self._batch_transform(x, self.scaler.transform, batch_size)
+        else:
+            return self.scaler.transform(x)

-    def inverse_transform(self, x):
-        return self.scaler.inverse_transform(x)
+    def inverse_transform(self, x, batch_size=10000):
+        if self.mask_value is not None:
+            return self._batch_transform(x, self.scaler.inverse_transform, batch_size)
+        else:
+            return self.scaler.inverse_transform(x)

     @property
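
For intuition, here is a standalone NumPy sketch (not using the project's scaler) of the round trip that _batch_transform performs on each batch: masked entries become NaN, NaN propagates through the elementwise scaling, and the surviving NaNs are mapped back to the mask value:

import numpy as np

MASK = -999.  # assumed mask value

x = np.array([[1.0, 2.0],
              [3.0, MASK],
              [5.0, 6.0]])

x_nan = x.copy()
x_nan[x_nan == MASK] = np.nan           # mask -> NaN
mean = np.nanmean(x_nan, axis=0)        # statistics ignore masked slots
std = np.nanstd(x_nan, axis=0)

x_out = (x_nan - mean) / std            # NaN propagates through the affine map
x_out[np.isnan(x_out)] = MASK           # NaN -> mask again

print(x_out)  # the masked slot is still exactly MASK after scaling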
@@ -606,6 +629,9 @@ class ClassificationProject(object):
     def _transform_data(self):
         if not self.data_transformed:

+            if self.mask_value is not None:
+                self.x_train[self.x_train == self.mask_value] = np.nan
+                self.x_test[self.x_test == self.mask_value] = np.nan

             if logger.level <= logging.DEBUG:
                 logger.debug("training data before transformation: {}".format(self.x_train))
                 logger.debug("minimum values: {}".format([np.min(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
@@ -618,6 +644,9 @@ class ClassificationProject(object):
                 logger.debug("training data after transformation: {}".format(self.x_train))
             self.x_test = self.scaler.transform(self.x_test)
             self.scaler.copy = orig_copy_setting
+            if self.mask_value is not None:
+                self.x_train[np.isnan(self.x_train)] = self.mask_value
+                self.x_test[np.isnan(self.x_test)] = self.mask_value
             self.data_transformed = True
             logger.info("Training and test data transformed")
@@ -1651,14 +1680,6 @@ class ClassificationProjectRNN(ClassificationProject):
         )

-    def _transform_data(self):
-        self.x_train[self.x_train == self.mask_value] = np.nan
-        self.x_test[self.x_test == self.mask_value] = np.nan
-        super(ClassificationProjectRNN, self)._transform_data()
-        self.x_train[np.isnan(self.x_train)] = self.mask_value
-        self.x_test[np.isnan(self.x_test)] = self.mask_value

     @property
     def model(self):
         if self._model is None:
@@ -1827,27 +1848,6 @@ class ClassificationProjectRNN(ClassificationProject):
         eval_score("train")

-    def _batch_transform(self, x, fn, batch_size):
-        "Transform array in batches, temporarily setting mask_values to nan"
-        transformed = np.empty(x.shape, dtype=x.dtype)
-        for start in range(0, len(x), batch_size):
-            stop = start+batch_size
-            x_batch = np.array(x[start:stop]) # copy
-            x_batch[x_batch == self.mask_value] = np.nan
-            x_batch = fn(x_batch)
-            x_batch[np.isnan(x_batch)] = self.mask_value
-            transformed[start:stop] = x_batch
-        return transformed
-
-    def transform(self, x, batch_size=10000):
-        return self._batch_transform(x, self.scaler.transform, batch_size)
-
-    def inverse_transform(self, x, batch_size=10000):
-        return self._batch_transform(x, self.scaler.inverse_transform, batch_size)

     def evaluate(self, x_eval, mode=None):
         logger.debug("Evaluate score for {}".format(x_eval))
         x_eval = self.transform(x_eval)
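
With these overrides removed, ClassificationProjectRNN falls back to the mask-aware methods it inherits from ClassificationProject, so both project types share one code path. A schematic sketch of the resulting lookup (the simplified body is a placeholder, not the real implementation):

class ClassificationProject(object):
    def transform(self, x, batch_size=10000):
        # simplified stand-in for the mask-aware version added above
        return x

class ClassificationProjectRNN(ClassificationProject):
    # no transform/inverse_transform/_batch_transform overrides any more
    pass

# transform is no longer defined on the subclass itself:
print('transform' in vars(ClassificationProjectRNN))  # False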