From 7f450694b085014d394213ecc0ccf7aad33c3cc3 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Tue, 28 Aug 2018 17:25:17 +0200
Subject: [PATCH] starting to implement function for plotting actmax events for
 all neurons

---
 plotting.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)

diff --git a/plotting.py b/plotting.py
index cdcde3f..10d61e1 100644
--- a/plotting.py
+++ b/plotting.py
@@ -234,6 +234,97 @@ def plot_NN_vs_var_2D_all(plotname, model, means,
     save_show(plt, fig, plotname, bbox_inches='tight')
 
 
+def plot_profile_2D_all(plotname, model,
+                        events,
+                        nbinsx, xmin, xmax,
+                        nbinsy, ymin, ymax,
+                        transform_function=None,
+                        varx_label=None,
+                        vary_label=None,
+                        zrange=None, logz=False,
+                        plot_last_layer=False,
+                        log_default_ymin=1e-5,
+                        cmap="inferno", **kwargs):
+
+    "Similar to plot_NN_2D, but creates a grid of plots for all neurons."
+
+    # transform
+    if transform_function is not None:
+        events = transform_function(events)
+
+    acts = get_activations(model, events, print_shape_only=True)
+
+    if plot_last_layer:
+        n_neurons = [len(i[0]) for i in acts]
+    else:
+        n_neurons = [len(i[0]) for i in acts[:-1]]
+    layers = len(n_neurons)
+
+    nrows_ncols = (layers, max(n_neurons))
+    fig = plt.figure(1, figsize=nrows_ncols)
+    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0,
+                     label_mode="1",
+                     aspect=False,
+                     cbar_location="top",
+                     cbar_mode="single",
+                     cbar_pad=.2,
+                     cbar_size="5%",)
+    grid_array = np.array(grid)
+    grid_array = grid_array.reshape(*nrows_ncols[::-1])
+
+    # leave out the last layer
+    global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]])
+    global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]])
+
+    logger.info("global_min: {}".format(global_min))
+    logger.info("global_max: {}".format(global_max))
+
+    output_min_default = 0
+    output_max_default = 1
+
+    if global_min <= 0 and logz:
+        global_min = log_default_ymin
+        logger.info("Changing global_min to {}".format(log_default_ymin))
+
+    ims = []
+    for layer in range(layers):
+        for neuron in range(len(acts[layer][0])):
+            acts_neuron = acts[layer][:,neuron]
+            ax = grid_array[neuron][layer]
+            extra_opts = {}
+            if not (plot_last_layer and layer == layers-1):
+                # for hidden layers, plot the same z-scale
+                if logz:
+                    norm = matplotlib.colors.LogNorm
+                else:
+                    norm = matplotlib.colors.Normalize
+                if zrange is not None:
+                    extra_opts["norm"] = norm(vmin=zrange[0], vmax=zrange[1])
+                else:
+                    extra_opts["norm"] = norm(vmin=global_min, vmax=global_max)
+            #im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
+            hist, xedges, yedges = get_profile_2D(
+                valsx, valsy, acts_neuron,
+                nbinsx, xmin, xmax,
+                nbinsy, ymin, ymax,
+                **kwargs
+            )
+            X, Y = np.meshgrid(xedges, yedges)
+            im = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extra_opts)
+            ax.set_facecolor("black")
+            if varx_label is not None:
+                ax.set_xlabel(varx_label)
+            if vary_label is not None:
+                ax.set_ylabel(vary_label)
+            ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")
+
+    cb = fig.colorbar(im, cax=grid[0].cax, orientation="horizontal")
+    cb.ax.xaxis.set_ticks_position('top')
+    cb.ax.xaxis.set_label_position('top')
+
+    save_show(plt, fig, plotname, bbox_inches='tight')
+
+
 def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"):
     X, Y = np.meshgrid(xedges, yedges)
 
-- 
GitLab