From e67e13f254bf44a126f912d76c473f6c9ab128d9 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Mon, 19 Nov 2018 16:14:30 +0100 Subject: [PATCH] use mse by default for regression --- toolkit.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/toolkit.py b/toolkit.py index 620d3f6..3bda85e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -182,7 +182,7 @@ class ClassificationProject(object): the first time. This seed (increased by one) is used again before training when keras shuffling is used. - :param loss: loss function name + :param loss: loss function name (or list of names in case of regression targets) :param mask_value: value that is used for non-existent entries (e.g. 4th jet pt in events with 3 jets) @@ -346,6 +346,9 @@ class ClassificationProject(object): self.shuffle_seed = shuffle_seed self.balance_dataset = balance_dataset self.loss = loss + if self.regression_branches and (not isinstance(self.loss, list)): + self.loss = [self.loss]+["mean_squared_error"]*len(self.regression_branches) + self.mask_value = mask_value self.apply_class_weight = apply_class_weight self.normalize_weights = normalize_weights @@ -865,7 +868,9 @@ class ClassificationProject(object): idx = self.train_val_idx[1] x_val, y_val, w_val = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_val_input = self.get_input_list(self.transform(x_val)) - return x_val_input, y_val, w_val + y_val_output = self.get_output_list(y_val) + w_val_list = self.get_weight_list(w_val) + return x_val_input, y_val_output, w_val_list @property @@ -874,7 +879,9 @@ class ClassificationProject(object): idx = self.train_val_idx[0] x_train, y_train, w_train = self.x_train[idx], self.y_train[idx], self.w_train_tot[idx] x_train_input = self.get_input_list(self.transform(x_train)) - return x_train_input, y_train, w_train + y_train_output = self.get_output_list(y_train) + w_train_list = self.get_weight_list(w_train) + return x_train_input, y_train_output, w_train_list @property @@ -917,6 +924,14 @@ class ClassificationProject(object): return np.hsplit(y, len(self.target_fields)+1) + def get_weight_list(self, w): + "Repeat weight n times for regression targets" + if not self.target_fields: + return w + else: + return [w]*(len(self.target_fields)+1) + + def yield_batch(self): "Batch generator - optionally shuffle the indices after each epoch" x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot @@ -936,7 +951,8 @@ class ClassificationProject(object): w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]] x_input = self.get_input_list(self.transform(x_batch)) y_output = self.get_output_list(y_batch) - yield (x_input, y_output, w_batch) + w_list = self.get_weight_list(w_batch) + yield (x_input, y_output, w_list) def yield_single_class_batch(self, class_label): @@ -955,12 +971,9 @@ class ClassificationProject(object): shuffled_idx = class_idx # yield them batch wise for start in range(0, len(shuffled_idx), int(self.batch_size/2)): - x_batch = x_train[shuffled_idx[start:start+int(self.batch_size/2)]] - y_batch = y_train[shuffled_idx[start:start+int(self.batch_size/2)]] - w_batch = w_train[shuffled_idx[start:start+int(self.batch_size/2)]] - x_input = self.get_input_list(self.transform(x_batch)) - y_output = self.get_output_list(y_batch) - yield (x_input, y_output, w_batch) + yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]], + y_train[shuffled_idx[start:start+int(self.batch_size/2)]], + w_train[shuffled_idx[start:start+int(self.batch_size/2)]]) def yield_balanced_batch(self): -- GitLab