Skip to content
Snippets Groups Projects
Commit 480a7679 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

all neuron plot in plot_NN_2D script

parent dc93ef02
No related branches found
No related tags found
No related merge requests found
......@@ -120,11 +120,13 @@ def plot_NN_vs_var_2D(plotname, means,
def plot_NN_vs_var_2D_all(plotname, model, means,
var1_index, var1_range,
var2_index, var2_range,
varx_index,
vary_index,
nbinsx, xmin, xmax,
nbinsy, ymin, ymax,
transform_function=None,
var1_label=None,
var2_label=None,
varx_label=None,
vary_label=None,
zrange=None, logz=False,
plot_last_layer=False,
log_default_ymin=1e-5,
......@@ -132,15 +134,15 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
"Similar to plot_NN_vs_var_2D, but creates a grid of plots for all neurons."
var1_vals = np.arange(*var1_range)
var2_vals = np.arange(*var2_range)
varx_vals = np.linspace(xmin, xmax, nbinsx)
vary_vals = np.linspace(ymin, ymax, nbinsy)
# create the events for which we want to fetch the activations
events = np.tile(means, len(var1_vals)*len(var2_vals)).reshape(len(var2_vals), len(var1_vals), -1)
for i, y in enumerate(var2_vals):
for j, x in enumerate(var1_vals):
events[i][j][var1_index] = x
events[i][j][var2_index] = y
events = np.tile(means, len(varx_vals)*len(vary_vals)).reshape(len(vary_vals), len(varx_vals), -1)
for i, y in enumerate(vary_vals):
for j, x in enumerate(varx_vals):
events[i][j][varx_index] = x
events[i][j][vary_index] = y
# convert back into 1d array
events = events.reshape(-1, len(means))
......@@ -187,7 +189,7 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
for layer in range(layers):
for neuron in range(len(acts[layer][0])):
acts_neuron = acts[layer][:,neuron]
acts_neuron = acts_neuron.reshape(len(var2_vals), len(var1_vals))
acts_neuron = acts_neuron.reshape(len(vary_vals), len(varx_vals))
ax = grid_array[neuron][layer]
extra_opts = {}
if not (plot_last_layer and layer == layers-1):
......@@ -200,12 +202,12 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
else:
extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
im = ax.pcolormesh(var1_vals, var2_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
ax.set_facecolor("black")
if var1_label is not None:
ax.set_xlabel(var1_label)
if var2_label is not None:
ax.set_ylabel(var2_label)
if varx_label is not None:
ax.set_xlabel(varx_label)
if vary_label is not None:
ax.set_ylabel(vary_label)
ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
......@@ -342,6 +344,8 @@ if __name__ == "__main__":
def test_mean_signal():
c._load_data() # untransformed
mean_signal = get_mean_event(c.x_test, c.y_test, 1)
print("Mean signal: ")
......@@ -371,9 +375,11 @@ if __name__ == "__main__":
plot_NN_vs_var_2D_all("mt_vs_met_all.pdf", means=mean_signal,
model=c.model, transform_function=c.scaler.transform,
var1_index=c.fields.index("met"), var1_range=(0, 1000, 10),
var2_index=c.fields.index("mt"), var2_range=(0, 500, 10),
var1_label="met [GeV]", var2_label="mt [GeV]")
varx_index=c.fields.index("met"),
vary_index=c.fields.index("mt"),
nbinsx=100, xmin=0, xmax=1000,
nbinsy=100, ymin=0, ymax=500,
varx_label="met [GeV]", vary_label="mt [GeV]")
plot_NN_vs_var_2D("mt_vs_met_crosscheck.pdf", means=mean_signal,
scorefun=get_single_neuron_function(c.model, layer=3, neuron=0, scaler=c.scaler),
......
......@@ -7,13 +7,18 @@ logging.basicConfig()
import numpy as np
import ROOT
ROOT.gROOT.SetBatch()
ROOT.PyConfig.IgnoreCommandLineOptions = True
from KerasROOTClassification import ClassificationProject
from KerasROOTClassification.plotting import (
get_mean_event,
plot_NN_vs_var_2D,
plot_profile_2D,
plot_hist_2D_events,
plot_cond_avg_actmax_2D
plot_cond_avg_actmax_2D,
plot_NN_vs_var_2D_all,
)
from KerasROOTClassification.utils import get_single_neuron_function, get_max_activation_events
......@@ -27,6 +32,7 @@ parser.add_argument("-m", "--mode",
default="mean_sig")
parser.add_argument("-l", "--layer", type=int, help="Layer index (takes last layer by default)")
parser.add_argument("-n", "--neuron", type=int, default=0, help="Neuron index (takes first neuron by default)")
parser.add_argument("-a", "--all-neurons", action="store_true", help="Create a summary plot for all neurons in all hidden layers")
parser.add_argument("--log", action="store_true", help="Plot in color in log scale")
parser.add_argument("--contour", action="store_true", help="Interpolate with contours")
parser.add_argument("-b", "--nbins", default=20, type=int, help="Number of bins in x and y direction")
......@@ -42,6 +48,9 @@ parser.add_argument("-s", "--step-size", help="step size for activation maximisa
args = parser.parse_args()
if args.all_neurons and (not args.mode.startswith("mean")):
parser.error("--all-neurons currently only supported for mean_sig and mean_bkg")
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
......@@ -90,17 +99,31 @@ if args.mode.startswith("mean"):
elif args.mode == "mean_bkg":
means = get_mean_event(c.x_test, c.y_test, 0)
plot_NN_vs_var_2D(
args.output_filename,
means=means,
varx_index=varx_index,
vary_index=vary_index,
scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
varx_label=varx_label, vary_label=vary_label,
logscale=args.log, only_pixels=(not args.contour)
)
if not args.all_neurons:
plot_NN_vs_var_2D(
args.output_filename,
means=means,
varx_index=varx_index,
vary_index=vary_index,
scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
varx_label=varx_label, vary_label=vary_label,
logscale=args.log, only_pixels=(not args.contour)
)
else:
plot_NN_vs_var_2D_all(
args.output_filename,
means=means,
model=c.model,
transform_function=c.scaler.transform,
varx_index=varx_index,
vary_index=vary_index,
xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
logz=args.log,
plot_last_layer=False,
)
elif args.mode.startswith("profile"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment