Skip to content
Snippets Groups Projects
Commit f205e49d authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

scaling hists for input plots by class weight instead of multiplying to sample weights

parent 9879ac31
No related branches found
No related tags found
No related merge requests found
......@@ -198,8 +198,6 @@ class KerasROOTClassification(object):
self._scaler = None
self._class_weight = None
self._bkg_weights = None
self._sig_weights = None
self._model = None
self._history = None
self._callbacks_list = []
......@@ -586,34 +584,12 @@ class KerasROOTClassification(object):
pass
@property
def bkg_weights(self):
"""
class weights multiplied by event weights (for plotting)
TODO: find a better way to do this
"""
if self._bkg_weights is None:
logger.debug("Calculating background weights for plotting")
self._bkg_weights = np.empty(sum(self.y_train == 0))
self._bkg_weights.fill(self.class_weight[0])
self._bkg_weights *= self.w_train[self.y_train == 0]
logger.debug("Background weights: {}".format(self._bkg_weights))
return self._bkg_weights
@property
def sig_weights(self):
"""
class weights multiplied by event weights (for plotting)
TODO: find a better way to do this
"""
if self._sig_weights is None:
logger.debug("Calculating signal weights for plotting")
self._sig_weights = np.empty(sum(self.y_train == 1))
self._sig_weights.fill(self.class_weight[1])
self._sig_weights *= self.w_train[self.y_train == 1]
logger.debug("Signal weights: {}".format(self._sig_weights))
return self._sig_weights
@staticmethod
def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
hist, bins = np.histogram(x, **np_kwargs)
centers = (bins[:-1] + bins[1:]) / 2
hist *= scale_factor
return centers, hist
def plot_input(self, var_index):
......@@ -622,6 +598,8 @@ class KerasROOTClassification(object):
fig, ax = plt.subplots()
bkg = self.x_train[:,var_index][self.y_train == 0]
sig = self.x_train[:,var_index][self.y_train == 1]
bkg_weights = self.w_train[self.y_train == 0]
sig_weights = self.w_train[self.y_train == 1]
logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
......@@ -635,14 +613,19 @@ class KerasROOTClassification(object):
logger.debug("Calculated range based on percentiles: {}".format(plot_range))
try:
ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
except ValueError:
# weird, probably not always working workaround for a numpy bug
plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
width = centers_sig[1]-centers_sig[0]
ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
ax.set_xlabel(branch+" (transformed)")
plot_dir = os.path.join(self.project_dir, "plots")
if not os.path.exists(plot_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment