Skip to content
Snippets Groups Projects
Commit e67e13f2 authored by Nikolai.Hartmann
Browse files

use mse by default for regression

parent e4212c7c
No related branches found
No related tags found
No related merge requests found
...@@ -182,7 +182,7 @@ class ClassificationProject(object): ...@@ -182,7 +182,7 @@ class ClassificationProject(object):
the first time. This seed (increased by one) is used again before the first time. This seed (increased by one) is used again before
training when keras shuffling is used. training when keras shuffling is used.
:param loss: loss function name :param loss: loss function name (or list of names in case of regression targets)
:param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets) :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets)
...@@ -346,6 +346,9 @@ class ClassificationProject(object): ...@@ -346,6 +346,9 @@ class ClassificationProject(object):
self.shuffle_seed = shuffle_seed self.shuffle_seed = shuffle_seed
self.balance_dataset = balance_dataset self.balance_dataset = balance_dataset
self.loss = loss self.loss = loss
if self.regression_branches and (not isinstance(self.loss, list)):
self.loss = [self.loss]+["mean_squared_error"]*len(self.regression_branches)
self.mask_value = mask_value self.mask_value = mask_value
self.apply_class_weight = apply_class_weight self.apply_class_weight = apply_class_weight
self.normalize_weights = normalize_weights self.normalize_weights = normalize_weights
...@@ -865,7 +868,9 @@ class ClassificationProject(object): ...@@ -865,7 +868,9 @@ class ClassificationProject(object):
idx = self.train_val_idx[1] idx = self.train_val_idx[1]
x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_val_input = self.get_input_list(self.transform(x_val)) x_val_input = self.get_input_list(self.transform(x_val))
return x_val_input, y_val, w_val y_val_output = self.get_output_list(y_val)
w_val_list = self.get_weight_list(w_val)
return x_val_input, y_val_output, w_val_list
@property @property
...@@ -874,7 +879,9 @@ class ClassificationProject(object): ...@@ -874,7 +879,9 @@ class ClassificationProject(object):
idx = self.train_val_idx[0] idx = self.train_val_idx[0]
x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx]
x_train_input = self.get_input_list(self.transform(x_train)) x_train_input = self.get_input_list(self.transform(x_train))
return x_train_input, y_train, w_train y_train_output = self.get_output_list(y_train)
w_train_list = self.get_weight_list(w_train)
return x_train_input, y_train_output, w_train_list
@property @property
...@@ -917,6 +924,14 @@ class ClassificationProject(object): ...@@ -917,6 +924,14 @@ class ClassificationProject(object):
return np.hsplit(y, len(self.target_fields)+1) return np.hsplit(y, len(self.target_fields)+1)
def get_weight_list(self, w):
    """Repeat weight n times for regression targets.

    :param w: sample weights to pass along with the targets
    :returns: ``w`` itself when ``self.target_fields`` is empty; otherwise a
        list holding ``len(self.target_fields) + 1`` references to the same
        ``w`` object (the weights are not copied)
    """
    if not self.target_fields:
        # No additional targets configured: hand the weights back unchanged.
        return w
    # One entry per target field plus one more -- presumably the extra entry
    # belongs to the classification output (TODO confirm against the model's
    # output layout).
    return [w] * (len(self.target_fields) + 1)
def yield_batch(self): def yield_batch(self):
"Batch generator - optionally shuffle the indices after each epoch" "Batch generator - optionally shuffle the indices after each epoch"
x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot
...@@ -936,7 +951,8 @@ class ClassificationProject(object): ...@@ -936,7 +951,8 @@ class ClassificationProject(object):
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]] w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
x_input = self.get_input_list(self.transform(x_batch)) x_input = self.get_input_list(self.transform(x_batch))
y_output = self.get_output_list(y_batch) y_output = self.get_output_list(y_batch)
yield (x_input, y_output, w_batch) w_list = self.get_weight_list(w_batch)
yield (x_input, y_output, w_list)
def yield_single_class_batch(self, class_label): def yield_single_class_batch(self, class_label):
...@@ -955,12 +971,9 @@ class ClassificationProject(object): ...@@ -955,12 +971,9 @@ class ClassificationProject(object):
shuffled_idx = class_idx shuffled_idx = class_idx
# yield them batch wise # yield them batch wise
for start in range(0, len(shuffled_idx), int(self.batch_size/2)): for start in range(0, len(shuffled_idx), int(self.batch_size/2)):
x_batch = x_train[shuffled_idx[start:start+int(self.batch_size/2)]] yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]],
y_batch = y_train[shuffled_idx[start:start+int(self.batch_size/2)]] y_train[shuffled_idx[start:start+int(self.batch_size/2)]],
w_batch = w_train[shuffled_idx[start:start+int(self.batch_size/2)]] w_train[shuffled_idx[start:start+int(self.batch_size/2)]])
x_input = self.get_input_list(self.transform(x_batch))
y_output = self.get_output_list(y_batch)
yield (x_input, y_output, w_batch)
def yield_balanced_batch(self): def yield_balanced_batch(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment