Skip to content
Snippets Groups Projects
Commit 9dbc0793 authored by Nikolai Hartmann
Browse files

Do transformations only in batch generator - not globally

parent 778a981d
No related branches found
No related tags found
No related merge requests found
......@@ -377,7 +377,6 @@ class ClassificationProject(object):
self.total_epochs = 0
self.data_loaded = False
self.data_transformed = False
# track if we are currently training
self.is_training = False
......@@ -649,30 +648,6 @@ class ClassificationProject(object):
json.dump(self.history.history, of)
def _transform_data(self):
if not self.data_transformed:
if self.mask_value is not None:
self.x_train[self.x_train == self.mask_value] = np.nan
self.x_test[self.x_test == self.mask_value] = np.nan
if logger.level <= logging.DEBUG:
logger.debug("training data before transformation: {}".format(self.x_train))
logger.debug("minimum values: {}".format([np.min(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
for i in range(self.x_train.shape[1])]))
logger.debug("maximum values: {}".format([np.max(self.x_train[:,i][~np.isnan(self.x_train[:,i])])
for i in range(self.x_train.shape[1])]))
orig_copy_setting = self.scaler.copy
self.scaler.copy = False
self.x_train = self.scaler.transform(self.x_train)
logger.debug("training data after transformation: {}".format(self.x_train))
self.x_test = self.scaler.transform(self.x_test)
self.scaler.copy = orig_copy_setting
if self.mask_value is not None:
self.x_train[np.isnan(self.x_train)] = self.mask_value
self.x_test[np.isnan(self.x_test)] = self.mask_value
self.data_transformed = True
logger.info("Training and test data transformed")
def _read_info(self, key, default):
filename = os.path.join(self.project_dir, "info.json")
if not os.path.exists(filename):
......@@ -812,14 +787,10 @@ class ClassificationProject(object):
if reload:
self.data_loaded = False
self.data_transformed = False
if not self.data_loaded:
self._load_data()
if not self.data_transformed:
self._transform_data()
@property
def w_train_tot(self):
......@@ -842,19 +813,19 @@ class ClassificationProject(object):
@property
def validation_data(self):
"Validation data for loss evaluation"
"(Transformed) validation data for loss evaluation"
idx = self.train_val_idx[1]
x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_val_input = self.get_input_list(x_val)
x_val_input = self.get_input_list(self.transform(x_val))
return x_val_input, y_val, w_val
@property
def training_data(self):
"Training data with validation data split off"
"(Transformed) Training data with validation data split off"
idx = self.train_val_idx[0]
x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_train_input = self.get_input_list(x_train)
x_train_input = self.get_input_list(self.transform(x_train))
return x_train_input, y_train, w_train
......@@ -907,7 +878,7 @@ class ClassificationProject(object):
x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
x_input = self.get_input_list(x_batch)
x_input = self.get_input_list(self.transform(x_batch))
yield (x_input, y_batch, w_batch)
......@@ -1006,8 +977,6 @@ class ClassificationProject(object):
def evaluate_train_test(self, do_train=True, do_test=True, mode=None):
logger.info("Reloading (and re-transforming) training data")
self.load(reload=True)
if mode is not None:
self._write_info("scores_mode", mode)
......@@ -1189,8 +1158,6 @@ class ClassificationProject(object):
ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
label = branch
if self.data_transformed:
label += " (transformed)"
ax.set_xlabel(label)
if fig is not None:
plot_dir = os.path.join(self.project_dir, "plots")
......@@ -1699,7 +1666,6 @@ class ClassificationProjectDataFrame(ClassificationProject):
if reload:
self.data_loaded = False
self.data_transformed = False
self._x_train = None
self._x_test = None
self._y_train = None
......@@ -1710,9 +1676,6 @@ class ClassificationProjectDataFrame(ClassificationProject):
self.data_loaded = True
if not self.data_transformed:
self._transform_data()
class ClassificationProjectRNN(ClassificationProject):
......@@ -1875,8 +1838,6 @@ class ClassificationProjectRNN(ClassificationProject):
def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
logger.info("Reloading (and re-transforming) unshuffled training data")
self.load(reload=True)
if mode is not None:
self._write_info("scores_mode", mode)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment