Skip to content
Snippets Groups Projects
Commit 9d83f5eb authored by Nikolai's avatar Nikolai
Browse files

trying to consistently treat masking

parent c055faa6
No related branches found
No related tags found
No related merge requests found
......@@ -32,8 +32,15 @@ def save_show(plt, fig, filename):
return None
def get_mean_event(x, y, class_label):
return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])]
def get_mean_event(x, y, class_label, mask_value=None):
means = []
for var_index in range(x.shape[1]):
if mask_value is not None:
masked_values = np.where(x[:,var_index] == mask_value)[0]
x = x[masked_values]
y = y[masked_values]
means.append(np.mean(x[y==class_label][:,var_index]))
return means
def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None):
......
......@@ -73,13 +73,21 @@ else:
varx_label = args.varx
vary_label = args.vary
# percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
# percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)]
percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights)
percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights)
try:
mask_value = c.mask_value
except NameError:
mask_value = None
varx_test = c.x_test[:,varx_index]
vary_test = c.x_test[:,vary_index]
x_not_masked = np.where(varx_test != mask_value)[0]
y_not_masked = np.where(vary_test != mask_value)[0]
percentilesx = weighted_quantile(varx_test[x_not_masked], [0.1, 0.99], sample_weight=total_weights[x_not_masked])
percentilesy = weighted_quantile(vary_test[y_not_masked], [0.1, 0.99], sample_weight=total_weights[y_not_masked])
if args.xrange is not None:
if len(args.xrange) < 3:
......@@ -100,9 +108,11 @@ else:
if args.mode.startswith("mean"):
if args.mode == "mean_sig":
means = get_mean_event(c.x_test, c.y_test, 1)
means = get_mean_event(c.x_test, c.y_test, 1, mask_value=mask_value)
elif args.mode == "mean_bkg":
means = get_mean_event(c.x_test, c.y_test, 0)
means = get_mean_event(c.x_test, c.y_test, 0, mask_value=mask_value)
print(means)
if hasattr(c, "get_input_list"):
input_transform = c.get_input_list
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment