Skip to content
Snippets Groups Projects
Commit 5ce28def authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

weighted quantile for plot_NN_2D

parent 480a7679
No related branches found
No related tags found
No related merge requests found
......@@ -20,7 +20,7 @@ from KerasROOTClassification.plotting import (
plot_cond_avg_actmax_2D,
plot_NN_vs_var_2D_all,
)
from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events
from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events, weighted_quantile
parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron')
parser.add_argument("project_dir")
......@@ -73,8 +73,13 @@ else:
varx_label = args.varx
vary_label = args.vary
percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
# percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
# percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)]
percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights)
percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights)
if args.xrange is not None:
if len(args.xrange) < 3:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment