Skip to content
Snippets Groups Projects
Commit 8f889d68 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

adjust plot_profile_2D_all to flatten multidimensional layer outputs

parent f92546b5
No related branches found
No related tags found
No related merge requests found
...@@ -258,9 +258,9 @@ def plot_profile_2D_all(plotname, model, events, ...@@ -258,9 +258,9 @@ def plot_profile_2D_all(plotname, model, events,
logger.info("Done") logger.info("Done")
if plot_last_layer: if plot_last_layer:
n_neurons = [len(i[0]) for i in acts] n_neurons = [len(i.reshape(i.shape[0], -1)[0]) for i in acts]
else: else:
n_neurons = [len(i[0]) for i in acts[:-1]] n_neurons = [len(i.reshape(i.shape[0], -1)[0]) for i in acts[:-1]]
layers = len(n_neurons) layers = len(n_neurons)
nrows_ncols = (layers, max(n_neurons)) nrows_ncols = (layers, max(n_neurons))
...@@ -282,8 +282,10 @@ def plot_profile_2D_all(plotname, model, events, ...@@ -282,8 +282,10 @@ def plot_profile_2D_all(plotname, model, events,
ims = [] ims = []
reg_plots = [] reg_plots = []
for layer in range(layers): for layer in range(layers):
for neuron in range(len(acts[layer][0])): neurons_acts = acts[layer]
acts_neuron = acts[layer][:,neuron] neurons_acts = neurons_acts.reshape(neurons_acts.shape[0], -1)
for neuron in range(len(neurons_acts[0])):
acts_neuron = neurons_acts[:,neuron]
ax = grid_array[neuron][layer] ax = grid_array[neuron][layer]
extra_opts = {} extra_opts = {}
if not (plot_last_layer and layer == layers-1): if not (plot_last_layer and layer == layers-1):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment