Skip to content
Snippets Groups Projects
Commit c8003f9a authored by Nikolai's avatar Nikolai
Browse files

plot_profile_2D_all working

parent 1c275a61
Branches dev-actmax
No related tags found
No related merge requests found
...@@ -235,7 +235,7 @@ def plot_NN_vs_var_2D_all(plotname, model, means, ...@@ -235,7 +235,7 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
def plot_profile_2D_all(plotname, model, events, def plot_profile_2D_all(plotname, model, events,
varx_index, vary_index, valsx, valsy,
nbinsx, xmin, xmax, nbinsx, xmin, xmax,
nbinsy, ymin, ymax, nbinsy, ymin, ymax,
transform_function=None, transform_function=None,
...@@ -246,16 +246,15 @@ def plot_profile_2D_all(plotname, model, events, ...@@ -246,16 +246,15 @@ def plot_profile_2D_all(plotname, model, events,
log_default_ymin=1e-5, log_default_ymin=1e-5,
cmap="inferno", **kwargs): cmap="inferno", **kwargs):
"Similar to plot_NN_2D, but creates a grid of plots for all neurons." "Similar to plot_profile_2D, but creates a grid of plots for all neurons."
valsx = np.array(events[:,varx_index])
valsy = np.array(events[:,vary_index])
# transform # transform
if transform_function is not None: if transform_function is not None:
events = transform_function(events) events = transform_function(events)
logger.info("Reading activations for all neurons")
acts = get_activations(model, events, print_shape_only=True) acts = get_activations(model, events, print_shape_only=True)
logger.info("Done")
if plot_last_layer: if plot_last_layer:
n_neurons = [len(i[0]) for i in acts] n_neurons = [len(i[0]) for i in acts]
...@@ -275,21 +274,12 @@ def plot_profile_2D_all(plotname, model, events, ...@@ -275,21 +274,12 @@ def plot_profile_2D_all(plotname, model, events,
grid_array = np.array(grid) grid_array = np.array(grid)
grid_array = grid_array.reshape(*nrows_ncols[::-1]) grid_array = grid_array.reshape(*nrows_ncols[::-1])
# leave out the last layer global_min = None
global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]]) global_max = None
global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]])
logger.info("global_min: {}".format(global_min))
logger.info("global_max: {}".format(global_max))
output_min_default = 0
output_max_default = 1
if global_min <= 0 and logz:
global_min = log_default_ymin
logger.info("Changing global_min to {}".format(log_default_ymin))
logger.info("Creating profile histograms")
ims = [] ims = []
reg_plots = []
for layer in range(layers): for layer in range(layers):
for neuron in range(len(acts[layer][0])): for neuron in range(len(acts[layer][0])):
acts_neuron = acts[layer][:,neuron] acts_neuron = acts[layer][:,neuron]
...@@ -305,27 +295,46 @@ def plot_profile_2D_all(plotname, model, events, ...@@ -305,27 +295,46 @@ def plot_profile_2D_all(plotname, model, events,
extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1]) extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
else: else:
extra_opts["norm"] = norm(vmin=global_min, vmax=global_max) extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
#im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
hist, xedges, yedges = get_profile_2D( hist, xedges, yedges = get_profile_2D(
valsx, valsy, acts_neuron, valsx, valsy, acts_neuron,
nbinsx, xmin, xmax, nbinsx, xmin, xmax,
nbinsy, ymin, ymax, nbinsy, ymin, ymax,
**kwargs **kwargs
) )
if global_min is None or hist.min() < global_min:
global_min = hist.min()
if global_max is None or hist.max() > global_max:
global_max = hist.max()
X, Y = np.meshgrid(xedges, yedges) X, Y = np.meshgrid(xedges, yedges)
im = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extra_opts) reg_plots.append((layer, neuron, ax, (X, Y, hist), dict(cmap="inferno", linewidth=0, rasterized=True, **extra_opts)))
ax.set_facecolor("black") logger.info("Done")
if varx_label is not None:
ax.set_xlabel(varx_label) logger.info("global_min: {}".format(global_min))
if vary_label is not None: logger.info("global_max: {}".format(global_max))
ax.set_ylabel(vary_label)
ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white") if global_min <= 0 and logz:
global_min = log_default_ymin
logger.info("Changing global_min to {}".format(log_default_ymin))
for layer, neuron, ax, args, kwargs in reg_plots:
if zrange is None:
kwargs["norm"].vmin = global_min
kwargs["norm"].vmax = global_max
im = ax.pcolormesh(*args, **kwargs)
ax.set_facecolor("black")
if varx_label is not None:
ax.set_xlabel(varx_label)
if vary_label is not None:
ax.set_ylabel(vary_label)
ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal") cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
cb.ax.xaxis.set_ticks_position('top') cb.ax.xaxis.set_ticks_position('top')
cb.ax.xaxis.set_label_position('top') cb.ax.xaxis.set_label_position('top')
logger.info("Rendering")
save_show(plt, fig, plotname, bbox_inches='tight') save_show(plt, fig, plotname, bbox_inches='tight')
logger.info("Done")
def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"): def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment