diff --git a/toolkit.py b/toolkit.py
index ec074bdf0c9e06e32de108a40fbb0e9641177488..a97a84671840ffb675f8962dab6784cc77f004c4 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1164,17 +1164,51 @@ class ClassificationProject(object):
         return centers, hist, errors
 
 
-    def plot_input(self, var_index, ax=None):
-        "plot a single input variable"
+    def plot_input(self, var_index, ax=None, from_training_batches=False, max_n_batches=None):
+        """
+        plot a single input variable as a histogram (signal vs background)
+
+        :param from_training_batches: use data from training batch generator
+        :param max_n_batches: if training batch generator is used, just use
+                              this number of batches (otherwise steps_per_epoch is used)
+        """
         branch = self.fields[var_index]
         if ax is None:
             fig, ax = plt.subplots()
         else:
             fig = None
-        bkg = self.x_train[:,var_index][self.y_train == 0]
-        sig = self.x_train[:,var_index][self.y_train == 1]
-        bkg_weights = self.w_train_tot[self.y_train == 0]
-        sig_weights = self.w_train_tot[self.y_train == 1]
+
+        if not from_training_batches:
+            bkg = self.x_train[:,var_index][self.y_train == 0]
+            sig = self.x_train[:,var_index][self.y_train == 1]
+            bkg_weights = self.w_train_tot[self.y_train == 0]
+            sig_weights = self.w_train_tot[self.y_train == 1]
+        else:
+            bkg = None
+            sig = None
+            bkg_weights = None
+            sig_weights = None
+            if max_n_batches is not None:
+                n_batches = max_n_batches
+            else:
+                n_batches = self.steps_per_epoch
+            for i_batch, (x, y, w) in enumerate(self.yield_batch()):
+                if i_batch >= n_batches:  # i_batch is 0-based: stop after exactly n_batches batches
+                    break
+                bkg_batch = x[:,var_index][y==0]
+                sig_batch = x[:,var_index][y==1]
+                bkg_weights_batch = w[y==0]
+                sig_weights_batch = w[y==1]
+                if bkg is None:
+                    bkg = bkg_batch
+                    sig = sig_batch
+                    bkg_weights = bkg_weights_batch
+                    sig_weights = sig_weights_batch
+                else:
+                    bkg = np.concatenate([bkg, bkg_batch])
+                    sig = np.concatenate([sig, sig_batch])
+                    bkg_weights = np.concatenate([bkg_weights, bkg_weights_batch])
+                    sig_weights = np.concatenate([sig_weights, sig_weights_batch])
 
         if hasattr(self, "mask_value"):
             bkg_not_masked = np.where(bkg != self.mask_value)[0]
@@ -1238,13 +1272,13 @@ class ClassificationProject(object):
         return save_show(plt, fig, os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
 
 
-    def plot_all_inputs(self):
+    def plot_all_inputs(self, **kwargs):
         nrows = math.ceil(math.sqrt(len(self.fields)))
         fig, axes = plt.subplots(nrows=int(nrows), ncols=int(nrows),
                                  figsize=(3*nrows, 3*nrows),
                                  gridspec_kw=dict(wspace=0.4, hspace=0.4))
         for i in range(len(self.fields)):
-            self.plot_input(i, ax=axes.reshape(-1)[i])
+            self.plot_input(i, ax=axes.reshape(-1)[i], **kwargs)
         return save_show(plt, fig, os.path.join(self.project_dir, "all_inputs.pdf"))
 
 