Newer
Older
#!/usr/bin/env python
import argparse
import sys

import numpy as np

from KerasROOTClassification import ClassificationProject
from KerasROOTClassification.plotting import (
    get_mean_event,
    plot_NN_vs_var_2D,
    plot_profile_2D,
    plot_hist_2D_events,
)
from KerasROOTClassification.tfhelpers import get_single_neuron_function
# ---------------------------------------------------------------------------
# Command line interface.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Create various 2D plots for a single neuron')
parser.add_argument("project_dir", help="ClassificationProject directory to load")
parser.add_argument("output_filename", help="Output file passed to the plotting function")
parser.add_argument("varx", help="Input branch shown on the x axis")
parser.add_argument("vary", help="Input branch shown on the y axis")
parser.add_argument("-m", "--mode",
                    choices=["mean_sig", "mean_bkg", "profile_sig", "profile_bkg", "hist_sig", "hist_bkg"],
                    default="mean_sig",
                    help="Plot type; the _sig/_bkg suffix selects test events of class 1/0")
parser.add_argument("-l", "--layer", type=int, help="Layer index (takes last layer by default)")
parser.add_argument("-n", "--neuron", type=int, default=0, help="Neuron index (takes first neuron by default)")
parser.add_argument("--log", action="store_true", help="Plot in color in log scale")
parser.add_argument("--contour", action="store_true", help="Interpolate with contours")
parser.add_argument("-b", "--nbins", default=20, type=int, help="Number of bins in x and y direction")
# The axis ranges take "low high" plus an optional bin count for that axis;
# without the third value the --nbins setting is used.
parser.add_argument("-x", "--xrange", type=float, nargs="+",
                    help="x range: low high [nbins] "
                         "(default: 1st-99th percentile of the test sample, --nbins bins)")
parser.add_argument("-y", "--yrange", type=float, nargs="+",
                    help="y range: low high [nbins] "
                         "(default: 1st-99th percentile of the test sample, --nbins bins)")
parser.add_argument("-p", "--profile-metric", default="mean", choices=["mean", "average", "max"],
                    help="metric for profile modes ('average' is weighted by the test event weights)")
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Load the trained project and resolve the plotting configuration.
# ---------------------------------------------------------------------------
c = ClassificationProject(args.project_dir)

layer = args.layer
neuron = args.neuron

if layer is None:
    # Default advertised by --layer ("takes last layer by default").
    # NOTE(review): assumes c.layers is the index get_single_neuron_function
    # treats as the last layer -- confirm.
    layer = c.layers


def _branch_index(name):
    """Return the column of branch *name* in the project's input arrays.

    Exits through ``parser.error`` (usage message on stderr, exit status 2)
    instead of an uncaught ValueError traceback when *name* is not one of
    ``c.branches``.
    """
    try:
        return c.branches.index(name)
    except ValueError:
        parser.error("unknown variable {!r}; available branches: {}".format(
            name, ", ".join(str(branch) for branch in c.branches)))


def _resolve_range(user_range, default_low_high, option):
    """Return ``(low, high, nbins)`` for one plot axis.

    :param user_range: values given to *option* -- ``[low, high]`` or
        ``[low, high, nbins]`` -- or None when the option was not used.
    :param default_low_high: ``(low, high)`` used when *user_range* is None.
    :param option: option name, used in error messages.

    Without an explicit third value the bin count is ``args.nbins``.  argparse
    parses the third value as float (``type=float``), so it is checked to be a
    positive whole number and converted to ``int`` here, matching the integer
    type of ``--nbins``.  Any other number of values is a usage error.
    """
    if user_range is None:
        return (default_low_high[0], default_low_high[1], args.nbins)
    if len(user_range) not in (2, 3):
        parser.error("{} expects 2 or 3 values (low high [nbins]), got {}".format(
            option, len(user_range)))
    nbins = args.nbins
    if len(user_range) == 3:
        if user_range[2] < 1 or not float(user_range[2]).is_integer():
            parser.error("{}: nbins must be a positive integer, got {}".format(
                option, user_range[2]))
        nbins = int(user_range[2])
    return (user_range[0], user_range[1], nbins)


varx_index = _branch_index(args.varx)
vary_index = _branch_index(args.vary)

varx_label = args.varx
vary_label = args.vary

# Default axis ranges: 1st to 99th percentile of the variable in the test sample.
percentilesx = np.percentile(c.x_test[:, varx_index], [1, 99])
percentilesy = np.percentile(c.x_test[:, vary_index], [1, 99])

varx_range = _resolve_range(args.xrange, percentilesx, "--xrange")
vary_range = _resolve_range(args.yrange, percentilesy, "--yrange")
# ---------------------------------------------------------------------------
# Produce the requested plot.
# ---------------------------------------------------------------------------
# Every mode ends in "_sig" (test events with y_test == 1) or "_bkg"
# (y_test == 0); resolve the class label once for all modes.
class_index = 1 if args.mode.endswith("_sig") else 0

if args.mode.startswith("mean"):
    # Evaluate the selected neuron on the varx/vary grid, starting from the
    # mean event of the chosen class (presumably the remaining inputs stay at
    # their class means -- see get_mean_event / plot_NN_vs_var_2D).
    means = get_mean_event(c.x_test, c.y_test, class_index)
    plot_NN_vs_var_2D(
        args.output_filename,
        means=means,
        varx_index=varx_index,
        vary_index=vary_index,
        scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
        xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
        ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
        varx_label=varx_label, vary_label=vary_label,
        logscale=args.log, only_pixels=(not args.contour)
    )
else:
    # profile_* and hist_* modes work directly on the test events of one
    # class. Build the boolean mask and the class subset once instead of
    # re-masking c.x_test for every column.
    class_mask = (c.y_test == class_index)
    x_class = c.x_test[class_mask]
    valsx = x_class[:, varx_index]
    valsy = x_class[:, vary_index]

    if args.mode.startswith("profile"):
        metric_dict = {
            "mean": np.mean,
            "max": np.max,
            "average": np.average,
        }
        # NOTE(review): profile modes use c.scores_test, so --layer and
        # --neuron have no effect here -- confirm that is intended.
        scores = c.scores_test[class_mask].reshape(-1)

        opt_kwargs = dict()
        if args.profile_metric == "average":
            # np.average is the only metric above that accepts weights;
            # pass the event weights of the selected class.
            opt_kwargs["weights"] = c.w_test[class_mask]

        plot_profile_2D(
            args.output_filename,
            valsx, valsy, scores,
            xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
            ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
            metric=metric_dict[args.profile_metric],
            varx_label=varx_label, vary_label=vary_label,
            **opt_kwargs
        )
    elif args.mode.startswith("hist"):
        # Events of the selected class in the (varx, vary) plane, using the
        # test event weights.
        weights = c.w_test[class_mask]
        plot_hist_2D_events(
            args.output_filename,
            valsx, valsy,
            xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
            ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
            weights=weights,
            varx_label=varx_label, vary_label=vary_label,
        )