Skip to content
Snippets Groups Projects
Commit 9dbc0793 authored by Nikolai Hartmann's avatar Nikolai Hartmann
Browse files

Do transformations only in batch generator - not globally

parent 778a981d
No related branches found
No related tags found
No related merge requests found
...@@ -377,7 +377,6 @@ class ClassificationProject(object): ...@@ -377,7 +377,6 @@ class ClassificationProject(object):
self.total_epochs = 0 self.total_epochs = 0
self.data_loaded = False self.data_loaded = False
self.data_transformed = False
# track if we are currently training # track if we are currently training
self.is_training = False self.is_training = False
...@@ -649,30 +648,6 @@ class ClassificationProject(object): ...@@ -649,30 +648,6 @@ class ClassificationProject(object):
json.dump(self.history.history, of) json.dump(self.history.history, of)
def _transform_data(self):
if not self.data_transformed:
if self.mask_value is not None:
self.x_train[self.x_train == self.mask_value] = np.nan
self.x_test[self.x_test == self.mask_value] = np.nan
if logger.level <= logging.DEBUG:
logger.debug("training data before transformation: {}".format(self.x_train))
logger.debug("minimum values: {}".format([np.min(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
for i in range(self.x_train.shape[1])]))
logger.debug("maximum values: {}".format([np.max(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
for i in range(self.x_train.shape[1])]))
orig_copy_setting = self.scaler.copy
self.scaler.copy = False
self.x_train = self.scaler.transform(self.x_train)
logger.debug("training data after transformation: {}".format(self.x_train))
self.x_test = self.scaler.transform(self.x_test)
self.scaler.copy = orig_copy_setting
if self.mask_value is not None:
self.x_train[np.isnan(self.x_train)] = self.mask_value
self.x_test[np.isnan(self.x_test)] = self.mask_value
self.data_transformed = True
logger.info("Training and test data transformed")
def _read_info(self, key, default): def _read_info(self, key, default):
filename = os.path.join(self.project_dir, "info.json") filename = os.path.join(self.project_dir, "info.json")
if not os.path.exists(filename): if not os.path.exists(filename):
...@@ -812,14 +787,10 @@ class ClassificationProject(object): ...@@ -812,14 +787,10 @@ class ClassificationProject(object):
if reload: if reload:
self.data_loaded = False self.data_loaded = False
self.data_transformed = False
if not self.data_loaded: if not self.data_loaded:
self._load_data() self._load_data()
if not self.data_transformed:
self._transform_data()
@property @property
def w_train_tot(self): def w_train_tot(self):
...@@ -842,19 +813,19 @@ class ClassificationProject(object): ...@@ -842,19 +813,19 @@ class ClassificationProject(object):
@property @property
def validation_data(self): def validation_data(self):
"Validation data for loss evaluation" "(Transformed) validation data for loss evaluation"
idx = self.train_val_idx[1] idx = self.train_val_idx[1]
x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_val_input = self.get_input_list(x_val) x_val_input = self.get_input_list(self.transform(x_val))
return x_val_input, y_val, w_val return x_val_input, y_val, w_val
@property @property
def training_data(self): def training_data(self):
"Training data with validation data split off" "(Transformed) Training data with validation data split off"
idx = self.train_val_idx[0] idx = self.train_val_idx[0]
x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_train_input = self.get_input_list(x_train) x_train_input = self.get_input_list(self.transform(x_train))
return x_train_input, y_train, w_train return x_train_input, y_train, w_train
...@@ -907,7 +878,7 @@ class ClassificationProject(object): ...@@ -907,7 +878,7 @@ class ClassificationProject(object):
x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]] x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]] y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]] w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
x_input = self.get_input_list(x_batch) x_input = self.get_input_list(self.transform(x_batch))
yield (x_input, y_batch, w_batch) yield (x_input, y_batch, w_batch)
...@@ -1006,8 +977,6 @@ class ClassificationProject(object): ...@@ -1006,8 +977,6 @@ class ClassificationProject(object):
def evaluate_train_test(self, do_train=True, do_test=True, mode=None): def evaluate_train_test(self, do_train=True, do_test=True, mode=None):
logger.info("Reloading (and re-transforming) training data")
self.load(reload=True)
if mode is not None: if mode is not None:
self._write_info("scores_mode", mode) self._write_info("scores_mode", mode)
...@@ -1189,8 +1158,6 @@ class ClassificationProject(object): ...@@ -1189,8 +1158,6 @@ class ClassificationProject(object):
ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width) ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
label = branch label = branch
if self.data_transformed:
label += " (transformed)"
ax.set_xlabel(label) ax.set_xlabel(label)
if fig is not None: if fig is not None:
plot_dir = os.path.join(self.project_dir, "plots") plot_dir = os.path.join(self.project_dir, "plots")
...@@ -1699,7 +1666,6 @@ class ClassificationProjectDataFrame(ClassificationProject): ...@@ -1699,7 +1666,6 @@ class ClassificationProjectDataFrame(ClassificationProject):
if reload: if reload:
self.data_loaded = False self.data_loaded = False
self.data_transformed = False
self._x_train = None self._x_train = None
self._x_test = None self._x_test = None
self._y_train = None self._y_train = None
...@@ -1710,9 +1676,6 @@ class ClassificationProjectDataFrame(ClassificationProject): ...@@ -1710,9 +1676,6 @@ class ClassificationProjectDataFrame(ClassificationProject):
self.data_loaded = True self.data_loaded = True
if not self.data_transformed:
self._transform_data()
class ClassificationProjectRNN(ClassificationProject): class ClassificationProjectRNN(ClassificationProject):
...@@ -1875,8 +1838,6 @@ class ClassificationProjectRNN(ClassificationProject): ...@@ -1875,8 +1838,6 @@ class ClassificationProjectRNN(ClassificationProject):
def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None): def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
logger.info("Reloading (and re-transforming) unshuffled training data")
self.load(reload=True)
if mode is not None: if mode is not None:
self._write_info("scores_mode", mode) self._write_info("scores_mode", mode)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment