Skip to content
Snippets Groups Projects
Commit 5ce28def authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

weighted quantile for plot_NN_2D

parent 480a7679
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ from KerasROOTClassification.plotting import ( ...@@ -20,7 +20,7 @@ from KerasROOTClassification.plotting import (
plot_cond_avg_actmax_2D, plot_cond_avg_actmax_2D,
plot_NN_vs_var_2D_all, plot_NN_vs_var_2D_all,
) )
from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events, weighted_quantile
parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron') parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron')
parser.add_argument("project_dir") parser.add_argument("project_dir")
...@@ -73,8 +73,13 @@ else: ...@@ -73,8 +73,13 @@ else:
varx_label = args.varx varx_label = args.varx
vary_label = args.vary vary_label = args.vary
percentilesx = np.percentile(c.x_test[:,varx_index], [1,99]) # percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
percentilesy = np.percentile(c.x_test[:,vary_index], [1,99]) # percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)]
percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights)
percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights)
if args.xrange is not None: if args.xrange is not None:
if len(args.xrange) < 3: if len(args.xrange) < 3:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment