Skip to content
Snippets Groups Projects
Commit c77d64ff authored by Nikolai's avatar Nikolai
Browse files

put some masking functionality into base class

parent 9430ad7c
No related branches found
No related tags found
No related merge requests found
......@@ -183,6 +183,8 @@ class ClassificationProject(object):
:param loss: loss function name
:param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
"""
......@@ -245,7 +247,8 @@ class ClassificationProject(object):
tensorboard_opts=None,
random_seed=1234,
balance_dataset=False,
loss='binary_crossentropy'):
loss='binary_crossentropy',
mask_value=None):
self.name = name
self.signal_trees = signal_trees
......@@ -316,6 +319,7 @@ class ClassificationProject(object):
self.random_seed = random_seed
self.balance_dataset = balance_dataset
self.loss = loss
self.mask_value = mask_value
self.s_train = None
self.b_train = None
......@@ -562,12 +566,31 @@ class ClassificationProject(object):
return self._scaler
def transform(self, x):
return self.scaler.transform(x)
def _batch_transform(self, x, fn, batch_size):
"Transform array in batches, temporarily setting mask_values to nan"
transformed = np.empty(x.shape, dtype=x.dtype)
for start in range(0, len(x), batch_size):
stop = start+batch_size
x_batch = np.array(x[start:stop]) # copy
x_batch[x_batch == self.mask_value] = np.nan
x_batch = fn(x_batch)
x_batch[np.isnan(x_batch)] = self.mask_value
transformed[start:stop] = x_batch
return transformed
def transform(self, x, batch_size=10000):
if self.mask_value is not None:
return self._batch_transform(x, self.scaler.transform, batch_size)
else:
return self.scaler.transform(x)
def inverse_transform(self, x):
return self.scaler.inverse_transform(x)
def inverse_transform(self, x, batch_size=10000):
if self.mask_value is not None:
return self._batch_transform(x, self.scaler.inverse_transform, batch_size)
else:
return self.scaler.inverse_transform(x)
@property
......@@ -603,6 +626,9 @@ class ClassificationProject(object):
def _transform_data(self):
if not self.data_transformed:
if self.mask_value is not None:
self.x_train[self.x_train == self.mask_value] = np.nan
self.x_test[self.x_test == self.mask_value] = np.nan
if logger.level <= logging.DEBUG:
logger.debug("training data before transformation: {}".format(self.x_train))
logger.debug("minimum values: {}".format([np.min(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
......@@ -615,6 +641,9 @@ class ClassificationProject(object):
logger.debug("training data after transformation: {}".format(self.x_train))
self.x_test = self.scaler.transform(self.x_test)
self.scaler.copy = orig_copy_setting
if self.mask_value is not None:
self.x_train[np.isnan(self.x_train)] = self.mask_value
self.x_test[np.isnan(self.x_test)] = self.mask_value
self.data_transformed = True
logger.info("Training and test data transformed")
......@@ -1645,14 +1674,6 @@ class ClassificationProjectRNN(ClassificationProject):
)
def _transform_data(self):
self.x_train[self.x_train == self.mask_value] = np.nan
self.x_test[self.x_test == self.mask_value] = np.nan
super(ClassificationProjectRNN, self)._transform_data()
self.x_train[np.isnan(self.x_train)] = self.mask_value
self.x_test[np.isnan(self.x_test)] = self.mask_value
@property
def model(self):
if self._model is None:
......@@ -1784,27 +1805,6 @@ class ClassificationProjectRNN(ClassificationProject):
eval_score("train")
def _batch_transform(self, x, fn, batch_size):
"Transform array in batches, temporarily setting mask_values to nan"
transformed = np.empty(x.shape, dtype=x.dtype)
for start in range(0, len(x), batch_size):
stop = start+batch_size
x_batch = np.array(x[start:stop]) # copy
x_batch[x_batch == self.mask_value] = np.nan
x_batch = fn(x_batch)
x_batch[np.isnan(x_batch)] = self.mask_value
transformed[start:stop] = x_batch
return transformed
def transform(self, x, batch_size=10000):
return self._batch_transform(x, self.scaler.transform, batch_size)
def inverse_transform(self, x, batch_size=10000):
return self._batch_transform(x, self.scaler.inverse_transform, batch_size)
def evaluate(self, x_eval, mode=None):
logger.debug("Evaluate score for {}".format(x_eval))
x_eval = self.transform(x_eval)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment