diff --git a/plotting.py b/plotting.py index bfb906c06850f0ddc462909dceb8963bac0b8212..f0dd64e82e9b7726e16772e6cee56b816add778a 100644 --- a/plotting.py +++ b/plotting.py @@ -32,8 +32,15 @@ def save_show(plt, fig, filename): return None -def get_mean_event(x, y, class_label): - return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])] +def get_mean_event(x, y, class_label, mask_value=None): + means = [] + for var_index in range(x.shape[1]): + if mask_value is not None: + masked_values = np.where(x[:,var_index] == mask_value)[0] + x = x[masked_values] + y = y[masked_values] + means.append(np.mean(x[y==class_label][:,var_index])) + return means def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None): diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py index 04016cbd396e9e22ca2f3d76ab9dc5f167911a24..460934daca75730d6f3d0df7b2d881af6c1595af 100755 --- a/scripts/plot_NN_2D.py +++ b/scripts/plot_NN_2D.py @@ -73,13 +73,21 @@ else: varx_label = args.varx vary_label = args.vary -# percentilesx = np.percentile(c.x_test[:,varx_index], [1,99]) -# percentilesy = np.percentile(c.x_test[:,vary_index], [1,99]) - total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)] -percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights) -percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights) +try: + mask_value = c.mask_value +except NameError: + mask_value = None + +varx_test = c.x_test[:,varx_index] +vary_test = c.x_test[:,vary_index] + +x_not_masked = np.where(varx_test != mask_value)[0] +y_not_masked = np.where(vary_test != mask_value)[0] + +percentilesx = weighted_quantile(varx_test[x_not_masked], [0.1, 0.99], sample_weight=total_weights[x_not_masked]) +percentilesy = weighted_quantile(vary_test[y_not_masked], [0.1, 0.99], sample_weight=total_weights[y_not_masked]) if args.xrange is not None: if len(args.xrange) < 3: @@ -100,9 +108,11 @@ else: if args.mode.startswith("mean"): if args.mode == "mean_sig": - means = get_mean_event(c.x_test, c.y_test, 1) + means = get_mean_event(c.x_test, c.y_test, 1, mask_value=mask_value) elif args.mode == "mean_bkg": - means = get_mean_event(c.x_test, c.y_test, 0) + means = get_mean_event(c.x_test, c.y_test, 0, mask_value=mask_value) + + print(means) if hasattr(c, "get_input_list"): input_transform = c.get_input_list