diff --git a/toolkit.py b/toolkit.py index 3e7b5743ab9e0f5aa41ea7312599735ac35d83a7..e25b458e90c62f4e278f31f7a3793565220487e3 100755 --- a/toolkit.py +++ b/toolkit.py @@ -330,8 +330,8 @@ class ClassificationProject(object): self._scores_train = None self._scores_test = None - # class weighted validation data - self._w_validation = None + # class weighted training data (divided by mean) + self._w_train_tot = None self._s_eventlist_train = None self._b_eventlist_train = None @@ -354,8 +354,6 @@ class ClassificationProject(object): self._fields = None - self._mean_train_weight = None - @property def fields(self): @@ -549,7 +547,7 @@ class ClassificationProject(object): self._scaler = RobustScaler() elif self.scaler_type == "WeightedRobustScaler": self._scaler = WeightedRobustScaler() - scaler_fit_kwargs["weights"] = self.w_train*np.array(self.class_weight)[self.y_train.astype(int)] + scaler_fit_kwargs["weights"] = self.w_train_tot else: raise ValueError("Scaler type {} unknown".format(self.scaler_type)) logger.info("Fitting {} to training data".format(self.scaler_type)) @@ -760,6 +758,8 @@ class ClassificationProject(object): np.random.shuffle(self.y_train) np.random.set_state(rn_state) np.random.shuffle(self.w_train) + np.random.set_state(rn_state) + np.random.shuffle(self.w_train_tot) if self._scores_train is not None: logger.info("Shuffling scores, since they are also there") np.random.set_state(rn_state) @@ -767,35 +767,30 @@ class ClassificationProject(object): @property - def mean_train_weight(self): - if self._mean_train_weight is None: - self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]) - return self._mean_train_weight - - - @property - def w_validation(self): - "class weighted validation data weights" - split_index = int((1-self.validation_split)*len(self.x_train)) - if self._w_validation is None: - self._w_validation = np.array(self.w_train[split_index:]) - self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0] - self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1] - return self._w_validation/self.mean_train_weight + def w_train_tot(self): + "(sample weight * class weight), divided by mean" + if not self.balance_dataset: + class_weight = self.class_weight + else: + class_weight = self.balanced_class_weight + if self._w_train_tot is None: + self._w_train_tot = self.w_train*np.array(class_weight)[self.y_train.astype(int)] + #self._w_train_tot /= np.mean(self._w_train_tot) + return self._w_train_tot @property - def class_weighted_validation_data(self): - "class weighted validation data. Attention: Shuffle training data before using this!" + def validation_data(self): + "Validation data. Attention: Shuffle training data before using this!" split_index = int((1-self.validation_split)*len(self.x_train)) - return self.x_train[split_index:], self.y_train[split_index:], self.w_validation + return self.x_train[split_index:], self.y_train[split_index:], self.w_train_tot[split_index:] @property def training_data(self): - "training data with validation data split off. Attention: Shuffle training data before using this!" + "Training data with validation data split off. Attention: Shuffle training data before using this!" split_index = int((1-self.validation_split)*len(self.x_train)) - return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index] + return self.x_train[:split_index], self.y_train[:split_index], self.w_train_tot[:split_index] def yield_single_class_batch(self, class_label): @@ -812,7 +807,7 @@ class ClassificationProject(object): for start in range(0, len(shuffled_idx), int(self.batch_size/2)): yield (x_train[shuffled_idx[start:start+int(self.batch_size/2)]], y_train[shuffled_idx[start:start+int(self.batch_size/2)]], - w_train[shuffled_idx[start:start+int(self.batch_size/2)]]*self.balanced_class_weight[class_label]) + w_train[shuffled_idx[start:start+int(self.batch_size/2)]]) def yield_balanced_batch(self): @@ -834,6 +829,8 @@ class ClassificationProject(object): self.load() + self.shuffle_training_data() + for branch_index, branch in enumerate(self.fields): self.plot_input(branch_index) @@ -842,16 +839,15 @@ class ClassificationProject(object): logger.info("Train model") if not self.balance_dataset: try: - self.shuffle_training_data() self.is_training = True self.model.fit(self.x_train, # the reshape might be unnescessary here self.y_train.reshape(-1, 1), epochs=epochs, - validation_split = self.validation_split, + validation_split=self.validation_split, # we have to multiply by class weight since keras ignores class weight if sample weight is given # see https://github.com/keras-team/keras/issues/497 - sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]/self.mean_train_weight, + sample_weight=self.w_train_tot, shuffle=True, batch_size=self.batch_size, callbacks=self.callbacks_list) @@ -860,7 +856,6 @@ class ClassificationProject(object): logger.info("Interrupt training - continue with rest") else: try: - self.shuffle_training_data() # needed here too, in order to get correct validation data self.is_training = True labels, label_counts = np.unique(self.y_train, return_counts=True) logger.info("Training on balanced batches") @@ -868,7 +863,7 @@ class ClassificationProject(object): self.model.fit_generator(self.yield_balanced_batch(), steps_per_epoch=int(min(label_counts)/self.batch_size), epochs=epochs, - validation_data=self.class_weighted_validation_data, + validation_data=self.validation_data, callbacks=self.callbacks_list) self.is_training = False except KeyboardInterrupt: @@ -1012,33 +1007,39 @@ class ClassificationProject(object): fig, ax = plt.subplots() bkg = self.x_train[:,var_index][self.y_train == 0] sig = self.x_train[:,var_index][self.y_train == 1] - bkg_weights = self.w_train[self.y_train == 0] - sig_weights = self.w_train[self.y_train == 1] + bkg_weights = self.w_train_tot[self.y_train == 0] + sig_weights = self.w_train_tot[self.y_train == 1] + + if self.balance_dataset: + if len(sig) < len(bkg): + logger.warning("Plotting only up to {} bkg events, since we use balance_dataset".format(len(sig))) + bkg = bkg[0:len(sig)] + bkg_weights = bkg_weights[0:len(sig)] + else: + logger.warning("Plotting only up to {} sig events, since we use balance_dataset".format(len(bkg))) + sig = sig[0:len(bkg)] + sig_weights = sig_weights[0:len(bkg)] logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg)) logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig)) # calculate percentiles to get a heuristic for the range to be plotted - # (should in principle also be done with weights, but for now do it unweighted) - # range_sig = np.percentile(sig, [1, 99]) - # range_bkg = np.percentile(sig, [1, 99]) - # plot_range = (min(range_sig[0], range_bkg[0]), max(range_sig[1], range_sig[1])) plot_range = weighted_quantile( self.x_train[:,var_index], [0.01, 0.99], - sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)] + sample_weight=self.w_train_tot ) logger.debug("Calculated range based on percentiles: {}".format(plot_range)) try: - centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) - centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) + centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=50, range=plot_range, weights=sig_weights) + centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=50, range=plot_range, weights=bkg_weights) except ValueError: # weird, probably not always working workaround for a numpy bug plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1]))) logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range)) - centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) - centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) + centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=50, range=plot_range, weights=sig_weights) + centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=50, range=plot_range, weights=bkg_weights) width = centers_sig[1]-centers_sig[0] ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width) @@ -1600,7 +1601,7 @@ class ClassificationProjectRNN(ClassificationProject): self.model.fit_generator(self.yield_batch(), steps_per_epoch=int(len(self.training_data[0])/self.batch_size), epochs=epochs, - validation_data=self.class_weighted_validation_data, + validation_data=self.validation_data, callbacks=self.callbacks_list) self.is_training = False except KeyboardInterrupt: @@ -1635,9 +1636,9 @@ class ClassificationProjectRNN(ClassificationProject): @property - def class_weighted_validation_data(self): + def validation_data(self): "class weighted validation data. Attention: Shuffle training data before using this!" - x_val, y_val, w_val = super(ClassificationProjectRNN, self).class_weighted_validation_data + x_val, y_val, w_val = super(ClassificationProjectRNN, self).validation_data x_val_input = self.get_input_list(x_val) return x_val_input, y_val, w_val