Skip to content
Snippets Groups Projects
Commit f205e49d authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

scaling hists for input plots by class weight instead of multiplying to sample weights

parent 9879ac31
No related branches found
No related tags found
No related merge requests found
...@@ -198,8 +198,6 @@ class KerasROOTClassification(object): ...@@ -198,8 +198,6 @@ class KerasROOTClassification(object):
self._scaler = None self._scaler = None
self._class_weight = None self._class_weight = None
self._bkg_weights = None
self._sig_weights = None
self._model = None self._model = None
self._history = None self._history = None
self._callbacks_list = [] self._callbacks_list = []
...@@ -586,34 +584,12 @@ class KerasROOTClassification(object): ...@@ -586,34 +584,12 @@ class KerasROOTClassification(object):
pass pass
@property @staticmethod
def bkg_weights(self): def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
""" hist, bins = np.histogram(x, **np_kwargs)
class weights multiplied by event weights (for plotting) centers = (bins[:-1] + bins[1:]) / 2
TODO: find a better way to do this hist *= scale_factor
""" return centers, hist
if self._bkg_weights is None:
logger.debug("Calculating background weights for plotting")
self._bkg_weights = np.empty(sum(self.y_train == 0))
self._bkg_weights.fill(self.class_weight[0])
self._bkg_weights *= self.w_train[self.y_train == 0]
logger.debug("Background weights: {}".format(self._bkg_weights))
return self._bkg_weights
@property
def sig_weights(self):
"""
class weights multiplied by event weights (for plotting)
TODO: find a better way to do this
"""
if self._sig_weights is None:
logger.debug("Calculating signal weights for plotting")
self._sig_weights = np.empty(sum(self.y_train == 1))
self._sig_weights.fill(self.class_weight[1])
self._sig_weights *= self.w_train[self.y_train == 1]
logger.debug("Signal weights: {}".format(self._sig_weights))
return self._sig_weights
def plot_input(self, var_index): def plot_input(self, var_index):
...@@ -622,6 +598,8 @@ class KerasROOTClassification(object): ...@@ -622,6 +598,8 @@ class KerasROOTClassification(object):
fig, ax = plt.subplots() fig, ax = plt.subplots()
bkg = self.x_train[:,var_index][self.y_train == 0] bkg = self.x_train[:,var_index][self.y_train == 0]
sig = self.x_train[:,var_index][self.y_train == 1] sig = self.x_train[:,var_index][self.y_train == 1]
bkg_weights = self.w_train[self.y_train == 0]
sig_weights = self.w_train[self.y_train == 1]
logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg)) logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig)) logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
...@@ -635,14 +613,19 @@ class KerasROOTClassification(object): ...@@ -635,14 +613,19 @@ class KerasROOTClassification(object):
logger.debug("Calculated range based on percentiles: {}".format(plot_range)) logger.debug("Calculated range based on percentiles: {}".format(plot_range))
try: try:
ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights) centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights) centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
except ValueError: except ValueError:
# weird, probably not always working workaround for a numpy bug # weird, probably not always working workaround for a numpy bug
plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1]))) plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range)) logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights) centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights) centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
width = centers_sig[1]-centers_sig[0]
ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
ax.set_xlabel(branch+" (transformed)") ax.set_xlabel(branch+" (transformed)")
plot_dir = os.path.join(self.project_dir, "plots") plot_dir = os.path.join(self.project_dir, "plots")
if not os.path.exists(plot_dir): if not os.path.exists(plot_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment