Skip to content
Snippets Groups Projects
plot_single_neuron.py 2.66 KiB
Newer Older
#!/usr/bin/env python

import sys
import argparse

import numpy as np

from KerasROOTClassification import ClassificationProject
from KerasROOTClassification.plotting import get_mean_event, plot_NN_vs_var_2D
from KerasROOTClassification.tfhelpers import get_single_neuron_function

# Command-line interface.  Arguments are registered in the order in which
# they should show up in the --help output.
parser = argparse.ArgumentParser(
    description='Create various 2D plots for a single neuron'
)

# Required positionals: project directory, output file name and the names of
# the variables plotted on the x and y axis.
for positional in ("project_dir", "output_filename", "varx", "vary"):
    parser.add_argument(positional)

# Selects which mean event is computed further down in the script.
parser.add_argument(
    "-m", "--mode",
    choices=["mean_sig", "mean_bkg"],
    default="mean_sig",
)

# Which neuron to look at.
parser.add_argument(
    "-l", "--layer",
    type=int,
    help="Layer index (takes last layer by default)",
)
parser.add_argument(
    "-n", "--neuron",
    type=int,
    default=0,
    help="Neuron index (takes first neuron by default)",
)

# Rendering switches.
parser.add_argument(
    "--log",
    action="store_true",
    help="Plot in color in log scale",
)
parser.add_argument(
    "--contour",
    action="store_true",
    help="Interpolate with contours",
)

# Binning and axis ranges.  The range options accept one or more floats; the
# script later treats an optional third value as the number of bins.
parser.add_argument(
    "-b", "--nbins",
    default=20,
    type=int,
    help="Number of bins in x and y direction",
)
parser.add_argument(
    "-x", "--xrange",
    type=float,
    nargs="+",
    help="xrange (low, high)",
)
parser.add_argument(
    "-y", "--yrange",
    type=float,
    nargs="+",
    help="yrange (low, high)",
)

args = parser.parse_args()

# Open the classification project stored in the given directory.
c = ClassificationProject(args.project_dir)

# Neuron to inspect.  Without an explicit --layer the value of c.layers is
# used as the layer index.
# NOTE(review): the --layer help text promises "last layer" -- confirm that
# c.layers really is the index of that layer.
neuron = args.neuron
layer = args.layer if args.layer is not None else c.layers

# Positions of the two plotted variables within the project's branch list
# (list.index raises ValueError for a name that is not in c.branches).
varx_index, vary_index = (
    c.branches.index(name) for name in (args.varx, args.vary)
)

x_test = c.x_test

# 1st and 99th percentile of each variable's column in x_test; they serve as
# the default axis limits when no explicit range is given.
percentilesx, percentilesy = (
    np.percentile(x_test[:, column], [1, 99])
    for column in (varx_index, vary_index)
)

def _resolve_axis_range(requested, percentiles, default_nbins):
    """Return ``(low, high, nbins)`` for one plot axis.

    Parameters:
        requested: values given via ``--xrange``/``--yrange`` (a list of
            floats) or ``None`` when the option was not used.  Accepted forms
            are ``[low, high]`` and ``[low, high, nbins]``.
        percentiles: two-element sequence used as ``(low, high)`` when no
            range was requested.
        default_nbins: number of bins used when ``requested`` carries none.

    Returns:
        A tuple ``(low, high, nbins)``.  A bin count taken from ``requested``
        is converted to ``int`` because the command line parses every range
        value as a float.

    Raises:
        ValueError: if ``requested`` does not hold two or three values, or if
            its bin count is not a positive whole number.
    """
    if requested is None:
        return (percentiles[0], percentiles[1], default_nbins)

    if len(requested) == 2:
        low, high = requested
        return (low, high, default_nbins)

    if len(requested) == 3:
        low, high, nbins = requested
        # is_integer() is False for nan/inf as well, so those are rejected
        # here instead of blowing up in int().
        if not float(nbins).is_integer() or nbins < 1:
            raise ValueError(
                "number of bins must be a positive integer, got %r" % (nbins,)
            )
        return (low, high, int(nbins))

    # A single value used to end up as [low, nbins] and crashed later with an
    # IndexError; more than three values were silently ignored.
    raise ValueError(
        "expected 'low high' or 'low high nbins', got %d value(s)"
        % len(requested)
    )


# Axis ranges as (low, high, nbins).  Invalid input is reported through the
# argument parser so the user gets a usage message rather than a traceback.
try:
    varx_range = _resolve_axis_range(args.xrange, percentilesx, args.nbins)
except ValueError as err:
    parser.error("--xrange: %s" % err)

try:
    vary_range = _resolve_axis_range(args.yrange, percentilesy, args.nbins)
except ValueError as err:
    parser.error("--yrange: %s" % err)

# Class label handed to get_mean_event for each --mode choice (presumably
# 1 = signal, 0 = background).  argparse restricts --mode to these two keys.
target_class = {"mean_sig": 1, "mean_bkg": 0}[args.mode]
means = get_mean_event(c.x_test, c.y_test, target_class)

# Function giving the output of the selected neuron; the project's scaler is
# passed along to it.
neuron_output = get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler)

# Each range holds (low, high, nbins) at positions 0, 1 and 2.
plot_options = dict(
    means=means,
    varx_index=varx_index,
    vary_index=vary_index,
    scorefun=neuron_output,
    xmin=varx_range[0],
    xmax=varx_range[1],
    nbinsx=varx_range[2],
    ymin=vary_range[0],
    ymax=vary_range[1],
    nbinsy=vary_range[2],
    varx_label=args.varx,
    vary_label=args.vary,
    logscale=args.log,
    # --contour switches off the pixel-only rendering.
    only_pixels=(not args.contour),
)

plot_NN_vs_var_2D(args.output_filename, **plot_options)