From c8003f9a8d34bc3a50b00abd1c9337d02a89a941 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Wed, 29 Aug 2018 09:51:42 +0200 Subject: [PATCH] plot_profile_2D_all working --- plotting.py | 61 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/plotting.py b/plotting.py index d1d1d28..101fad0 100644 --- a/plotting.py +++ b/plotting.py @@ -235,7 +235,7 @@ def plot_NN_vs_var_2D_all(plotname, model, means, def plot_profile_2D_all(plotname, model, events, - varx_index, vary_index, + valsx, valsy, nbinsx, xmin, xmax, nbinsy, ymin, ymax, transform_function=None, @@ -246,16 +246,15 @@ def plot_profile_2D_all(plotname, model, events, log_default_ymin=1e-5, cmap="inferno", **kwargs): - "Similar to plot_NN_2D, but creates a grid of plots for all neurons." - - valsx = np.array(events[:,varx_index]) - valsy = np.array(events[:,vary_index]) + "Similar to plot_profile_2D, but creates a grid of plots for all neurons." # transform if transform_function is not None: events = transform_function(events) + logger.info("Reading activations for all neurons") acts = get_activations(model, events, print_shape_only=True) + logger.info("Done") if plot_last_layer: n_neurons = [len(i[0]) for i in acts] @@ -275,21 +274,12 @@ def plot_profile_2D_all(plotname, model, events, grid_array = np.array(grid) grid_array = grid_array.reshape(*nrows_ncols[::-1]) - # leave out the last layer - global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]]) - global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]]) - - logger.info("global_min: {}".format(global_min)) - logger.info("global_max: {}".format(global_max)) - - output_min_default = 0 - output_max_default = 1 - - if global_min <= 0 and logz: - global_min = log_default_ymin - logger.info("Changing global_min to {}".format(log_default_ymin)) + global_min = None + global_max = None + logger.info("Creating profile histograms") ims = [] + reg_plots = [] for layer in range(layers): for neuron in range(len(acts[layer][0])): acts_neuron = acts[layer][:,neuron] @@ -305,27 +295,46 @@ def plot_profile_2D_all(plotname, model, events, extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1]) else: extra_opts["norm"] = norm(vmin=global_min, vmax=global_max) - #im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts) hist, xedges, yedges = get_profile_2D( valsx, valsy, acts_neuron, nbinsx, xmin, xmax, nbinsy, ymin, ymax, **kwargs ) + if global_min is None or hist.min() < global_min: + global_min = hist.min() + if global_max is None or hist.max() > global_max: + global_max = hist.max() X, Y = np.meshgrid(xedges, yedges) - im = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extra_opts) - ax.set_facecolor("black") - if varx_label is not None: - ax.set_xlabel(varx_label) - if vary_label is not None: - ax.set_ylabel(vary_label) - ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white") + reg_plots.append((layer, neuron, ax, (X, Y, hist), dict(cmap="inferno", linewidth=0, rasterized=True, **extra_opts))) + logger.info("Done") + + logger.info("global_min: {}".format(global_min)) + logger.info("global_max: {}".format(global_max)) + + if global_min <= 0 and logz: + global_min = log_default_ymin + logger.info("Changing global_min to {}".format(log_default_ymin)) + + for layer, neuron, ax, args, kwargs in reg_plots: + if zrange is None: + kwargs["norm"].vmin = global_min + kwargs["norm"].vmax = global_max + im = ax.pcolormesh(*args, **kwargs) + ax.set_facecolor("black") + if varx_label is not None: + ax.set_xlabel(varx_label) + if vary_label is not None: + ax.set_ylabel(vary_label) + ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white") cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal") cb.ax.xaxis.set_ticks_position('top') cb.ax.xaxis.set_label_position('top') + logger.info("Rendering") save_show(plt, fig, plotname, bbox_inches='tight') + logger.info("Done") def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"): -- GitLab