From f205e49d422279bcccc88fb642c52d9e2dc4c1ab Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Tue, 8 May 2018 17:42:14 +0200 Subject: [PATCH] scaling hists for input plots by class weight instead of multiplying to sample weights --- toolkit.py | 51 +++++++++++++++++---------------------------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/toolkit.py b/toolkit.py index ee64a09..205b3ca 100755 --- a/toolkit.py +++ b/toolkit.py @@ -198,8 +198,6 @@ class KerasROOTClassification(object): self._scaler = None self._class_weight = None - self._bkg_weights = None - self._sig_weights = None self._model = None self._history = None self._callbacks_list = [] @@ -586,34 +584,12 @@ class KerasROOTClassification(object): pass - @property - def bkg_weights(self): - """ - class weights multiplied by event weights (for plotting) - TODO: find a better way to do this - """ - if self._bkg_weights is None: - logger.debug("Calculating background weights for plotting") - self._bkg_weights = np.empty(sum(self.y_train == 0)) - self._bkg_weights.fill(self.class_weight[0]) - self._bkg_weights *= self.w_train[self.y_train == 0] - logger.debug("Background weights: {}".format(self._bkg_weights)) - return self._bkg_weights - - - @property - def sig_weights(self): - """ - class weights multiplied by event weights (for plotting) - TODO: find a better way to do this - """ - if self._sig_weights is None: - logger.debug("Calculating signal weights for plotting") - self._sig_weights = np.empty(sum(self.y_train == 1)) - self._sig_weights.fill(self.class_weight[1]) - self._sig_weights *= self.w_train[self.y_train == 1] - logger.debug("Signal weights: {}".format(self._sig_weights)) - return self._sig_weights + @staticmethod + def get_bin_centered_hist(x, scale_factor=None, **np_kwargs): + hist, bins = np.histogram(x, **np_kwargs) + centers = (bins[:-1] + bins[1:]) / 2 + hist *= scale_factor + return centers, hist def plot_input(self, var_index): @@ -622,6 +598,8 @@ class KerasROOTClassification(object): fig, ax = plt.subplots() bkg = self.x_train[:,var_index][self.y_train == 0] sig = self.x_train[:,var_index][self.y_train == 1] + bkg_weights = self.w_train[self.y_train == 0] + sig_weights = self.w_train[self.y_train == 1] logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg)) logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig)) @@ -635,14 +613,19 @@ class KerasROOTClassification(object): logger.debug("Calculated range based on percentiles: {}".format(plot_range)) try: - ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights) - ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights) + centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) + centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) except ValueError: # weird, probably not always working workaround for a numpy bug plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1]))) logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range)) - ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights) - ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights) + centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights) + centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights) + + width = centers_sig[1]-centers_sig[0] + ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width) + ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width) + ax.set_xlabel(branch+" (transformed)") plot_dir = os.path.join(self.project_dir, "plots") if not os.path.exists(plot_dir): -- GitLab