Skip to content
Snippets Groups Projects
Commit 3b5bb93e authored by Nikolai Hartmann's avatar Nikolai Hartmann
Browse files

transform target for regression

parent 3c3d6c50
No related branches found
No related tags found
No related merge requests found
......@@ -379,6 +379,7 @@ class ClassificationProject(object):
self._b_eventlist_train = None
self._scaler = None
self._scaler_target = None
self._class_weight = None
self._balanced_class_weight = None
self._model = None
......@@ -621,6 +622,39 @@ class ClassificationProject(object):
return self._scaler
@property
def scaler_target(self):
"same as scaler, but for scaling regression targets"
# create the scaler (and fit to training data) if not existent
if self._scaler_target is None:
filename = os.path.join(self.project_dir, "scaler_target.pkl")
try:
self._scaler_target = joblib.load(filename)
logger.info("Loaded existing scaler from {}".format(filename))
except IOError:
logger.info("Creating new {} for scaling the targets".format(self.scaler_type))
scaler_fit_kwargs = dict()
if self.scaler_type == "StandardScaler":
self._scaler_target = StandardScaler()
elif self.scaler_type == "RobustScaler":
self._scaler_target = RobustScaler()
elif self.scaler_type == "WeightedRobustScaler":
self._scaler_target = WeightedRobustScaler()
scaler_fit_kwargs["weights"] = self.w_train_tot
else:
raise ValueError("Scaler type {} unknown".format(self.scaler_type))
logger.info("Fitting {} to training data".format(self.scaler_type))
orig_copy_setting = self.scaler.copy
self.scaler.copy = False
self._scaler_target.fit(self.y_train, **scaler_fit_kwargs)
# i don't want to scale the classification target here
self._scaler_target.center_[0] = 0.
self._scaler_target.scale_[0] = 1.
self.scaler.copy = orig_copy_setting
joblib.dump(self._scaler_target, filename)
return self._scaler_target
def _batch_transform(self, x, fn, batch_size):
"Transform array in batches, temporarily setting mask_values to nan"
transformed = np.empty(x.shape, dtype=x.dtype)
......@@ -648,6 +682,24 @@ class ClassificationProject(object):
return self.scaler.inverse_transform(x)
def transform_target(self, y, batch_size=10000):
if not self.target_fields:
return y
if self.mask_value is not None:
return self._batch_transform(y, self.scaler_target.transform, batch_size)
else:
return self.scaler_target.transform(y)
def inverse_transform_target(self, y, batch_size=10000):
if not self.target_fields:
return y
if self.mask_value is not None:
return self._batch_transform(y, self.scaler_target.inverse_transform, batch_size)
else:
return self.scaler_target.inverse_transform(y)
@property
def history(self):
params_file = os.path.join(self.project_dir, "history_params.json")
......@@ -873,7 +925,7 @@ class ClassificationProject(object):
idx = self.train_val_idx[1]
x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_val_input = self.get_input_list(self.transform(x_val))
y_val_output = self.get_output_list(y_val)
y_val_output = self.get_output_list(self.transform_target(y_val))
w_val_list = self.get_weight_list(w_val)
return x_val_input, y_val_output, w_val_list
......@@ -884,7 +936,7 @@ class ClassificationProject(object):
idx = self.train_val_idx[0]
x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_train_input = self.get_input_list(self.transform(x_train))
y_train_output = self.get_output_list(y_train)
y_train_output = self.get_output_list(self.transform_target(y_train))
w_train_list = self.get_weight_list(w_train)
return x_train_input, y_train_output, w_train_list
......@@ -955,7 +1007,7 @@ class ClassificationProject(object):
y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
x_input = self.get_input_list(self.transform(x_batch))
y_output = self.get_output_list(y_batch)
y_output = self.get_output_list(self.transform_target(y_batch))
w_list = self.get_weight_list(w_batch)
yield (x_input, y_output, w_list)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment