diff --git a/plotting.py b/plotting.py index 6d5f3c4e9ac2e06dde334b70c78ed8c2b4c499de..d0d73155927dbcbaf4668d51b58542b3ded0b5fb 100644 --- a/plotting.py +++ b/plotting.py @@ -258,9 +258,9 @@ def plot_profile_2D_all(plotname, model, events, logger.info("Done") if plot_last_layer: - n_neurons = [len(i[0]) for i in acts] + n_neurons = [len(i.reshape(i.shape[0], -1)[0]) for i in acts] else: - n_neurons = [len(i[0]) for i in acts[:-1]] + n_neurons = [len(i.reshape(i.shape[0], -1)[0]) for i in acts[:-1]] layers = len(n_neurons) nrows_ncols = (layers, max(n_neurons)) @@ -282,8 +282,10 @@ def plot_profile_2D_all(plotname, model, events, ims = [] reg_plots = [] for layer in range(layers): - for neuron in range(len(acts[layer][0])): - acts_neuron = acts[layer][:,neuron] + neurons_acts = acts[layer] + neurons_acts = neurons_acts.reshape(neurons_acts.shape[0], -1) + for neuron in range(len(neurons_acts[0])): + acts_neuron = neurons_acts[:,neuron] ax = grid_array[neuron][layer] extra_opts = {} if not (plot_last_layer and layer == layers-1):