diff --git a/plotting.py b/plotting.py index cdcde3f5e763874e4e0905a47591347f2b3dc612..10d61e16b3adc90cb9f81a24c07dfbf34e7096ac 100644 --- a/plotting.py +++ b/plotting.py @@ -234,6 +234,97 @@ def plot_NN_vs_var_2D_all(plotname, model, means, save_show(plt, fig, plotname, bbox_inches='tight') +def plot_profile_2D_all(plotname, model, + events, + nbinsx, xmin, xmax, + nbinsy, ymin, ymax, + transform_function=None, + varx_label=None, + vary_label=None, + zrange=None, logz=False, + plot_last_layer=False, + log_default_ymin=1e-5, + cmap="inferno", **kwargs): + + "Similar to plot_NN_2D, but creates a grid of plots for all neurons." + + # transform + if transform_function is not None: + events = transform_function(events) + + acts = get_activations(model, events, print_shape_only=True) + + if plot_last_layer: + n_neurons = [len(i[0]) for i in acts] + else: + n_neurons = [len(i[0]) for i in acts[:-1]] + layers = len(n_neurons) + + nrows_ncols = (layers, max(n_neurons)) + fig = plt.figure(1, figsize=nrows_ncols) + grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0, + label_mode="1", + aspect=False, + cbar_location="top", + cbar_mode="single", + cbar_pad=.2, + cbar_size="5%",) + grid_array = np.array(grid) + grid_array = grid_array.reshape(*nrows_ncols[::-1]) + + # leave out the last layer + global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]]) + global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]]) + + logger.info("global_min: {}".format(global_min)) + logger.info("global_max: {}".format(global_max)) + + output_min_default = 0 + output_max_default = 1 + + if global_min <= 0 and logz: + global_min = log_default_ymin + logger.info("Changing global_min to {}".format(log_default_ymin)) + + ims = [] + for layer in range(layers): + for neuron in range(len(acts[layer][0])): + acts_neuron = acts[layer][:,neuron] + ax = grid_array[neuron][layer] + extra_opts = {} + if not (plot_last_layer and layer == layers-1): + # for hidden layers, plot the same z-scale + if logz: + norm = matplotlib.colors.LogNorm + else: + norm = matplotlib.colors.Normalize + if zrange is not None: + extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1]) + else: + extra_opts["norm"] = norm(vmin=global_min, vmax=global_max) + #im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts) + hist, xedges, yedges = get_profile_2D( + valsx, valsy, acts_neuron, + nbinsx, xmin, xmax, + nbinsy, ymin, ymax, + **kwargs + ) + X, Y = np.meshgrid(xedges, yedges) + im = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extra_opts) + ax.set_facecolor("black") + if varx_label is not None: + ax.set_xlabel(varx_label) + if vary_label is not None: + ax.set_ylabel(vary_label) + ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white") + + cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal") + cb.ax.xaxis.set_ticks_position('top') + cb.ax.xaxis.set_label_position('top') + + save_show(plt, fig, plotname, bbox_inches='tight') + + def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"): X, Y = np.meshgrid(xedges, yedges)