Skip to content
Snippets Groups Projects
Commit 7f450694 authored by Nikolai's avatar Nikolai
Browse files

starting to implement function for plotting actmax events for all neurons

parent 887f1610
No related branches found
No related tags found
No related merge requests found
......@@ -234,6 +234,97 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
save_show(plt, fig, plotname, bbox_inches='tight')
def plot_profile_2D_all(plotname, model,
events,
nbinsx, xmin, xmax,
nbinsy, ymin, ymax,
transform_function=None,
varx_label=None,
vary_label=None,
zrange=None, logz=False,
plot_last_layer=False,
log_default_ymin=1e-5,
cmap="inferno", **kwargs):
"Similar to plot_NN_2D, but creates a grid of plots for all neurons."
# transform
if transform_function is not None:
events = transform_function(events)
acts = get_activations(model, events, print_shape_only=True)
if plot_last_layer:
n_neurons = [len(i[0]) for i in acts]
else:
n_neurons = [len(i[0]) for i in acts[:-1]]
layers = len(n_neurons)
nrows_ncols = (layers, max(n_neurons))
fig = plt.figure(1, figsize=nrows_ncols)
grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0,
label_mode="1",
aspect=False,
cbar_location="top",
cbar_mode="single",
cbar_pad=.2,
cbar_size="5%",)
grid_array = np.array(grid)
grid_array = grid_array.reshape(*nrows_ncols[::-1])
# leave out the last layer
global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]])
global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]])
logger.info("global_min: {}".format(global_min))
logger.info("global_max: {}".format(global_max))
output_min_default = 0
output_max_default = 1
if global_min <= 0 and logz:
global_min = log_default_ymin
logger.info("Changing global_min to {}".format(log_default_ymin))
ims = []
for layer in range(layers):
for neuron in range(len(acts[layer][0])):
acts_neuron = acts[layer][:,neuron]
ax = grid_array[neuron][layer]
extra_opts = {}
if not (plot_last_layer and layer == layers-1):
# for hidden layers, plot the same z-scale
if logz:
norm = matplotlib.colors.LogNorm
else:
norm = matplotlib.colors.Normalize
if zrange is not None:
extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
else:
extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
#im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
hist, xedges, yedges = get_profile_2D(
valsx, valsy, acts_neuron,
nbinsx, xmin, xmax,
nbinsy, ymin, ymax,
**kwargs
)
X, Y = np.meshgrid(xedges, yedges)
im = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extra_opts)
ax.set_facecolor("black")
if varx_label is not None:
ax.set_xlabel(varx_label)
if vary_label is not None:
ax.set_ylabel(vary_label)
ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
cb.ax.xaxis.set_ticks_position('top')
cb.ax.xaxis.set_label_position('top')
save_show(plt, fig, plotname, bbox_inches='tight')
def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"):
X, Y = np.meshgrid(xedges, yedges)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment