Skip to content
Snippets Groups Projects
plotting.py 23.47 KiB
#!/usr/bin/env python

import os
import math

import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib.ticker import LogFormatter
from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
import numpy as np

from .keras_visualize_activations.read_activations import get_activations
from .utils import get_grad_function, max_activation_wrt_input, create_random_event

import logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

"""
Further plotting functions: NN output and per-neuron activation scans, 2D histograms and profile plots.
"""

def save_show(plt, fig, filename, **kwargs):
    """Save *fig* to *filename*, and display it when running inside IPython/Jupyter.

    Extra keyword arguments are forwarded to ``fig.savefig``.

    Returns *fig* (after ``plt.show()``) when ``get_ipython`` is defined, i.e.
    inside an IPython session; otherwise the figure is closed and None is
    returned.
    """
    fig.savefig(filename, **kwargs)
    shown = fig
    try:
        get_ipython  # noqa: F821 -- only defined inside an IPython session
        plt.show()
    except NameError:
        plt.close(fig)
        shown = None
    return shown


def get_mean_event(x, y, class_label, mask_value=None):
    """Return the per-variable mean over all events of one class.

    Parameters
    ----------
    x : array indexed as ``x[event, variable]``
    y : array of class labels, one per event (compared with ``==`` to
        *class_label*)
    class_label : label of the class whose events are averaged
    mask_value : optional placeholder value; entries equal to it are left out
        of the mean of *their own* variable only

    Returns
    -------
    list with one mean per column of *x*
    """
    x = np.asarray(x)
    y = np.asarray(y)
    class_events = x[y == class_label]
    means = []
    for var_index in range(x.shape[1]):
        column = class_events[:, var_index]
        if mask_value is not None:
            # Mask per variable. Filtering x itself here (as done previously)
            # would also drop these events from every later variable's mean.
            column = column[column != mask_value]
        means.append(np.mean(column))
    return means


def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None):
    """Plot the NN output vs one variable with the other variables set to the given mean values.

    Parameters
    ----------
    plotname : output file name passed to save_show
    means : one value per input variable; every scanned event is a copy of it
    scorefun : callable evaluated on the (n_points, n_variables) event array;
        its result is plotted against the scanned values
    var_index : column of the scanned variable
    var_range : positional arguments for ``np.arange`` (e.g. start, stop, step)
    var_label : optional x-axis label
    """
    logger.info("Creating varied events (1d)")
    sequence = np.arange(*var_range)

    template = np.asarray(means)
    if not np.issubdtype(template.dtype, np.floating):
        # an integer template would silently truncate non-integer scan values
        template = template.astype(float)
    events = np.tile(template, (len(sequence), 1))
    events[:, var_index] = sequence

    logger.info("Predicting scores")
    scores = scorefun(events)

    fig, ax = plt.subplots()
    ax.plot(sequence, scores)
    if var_label is not None:
        ax.set_xlabel(var_label)
    ax.set_ylabel("NN output")
    save_show(plt, fig, plotname)


def plot_NN_vs_var_2D(plotname, means,
                      scorefun,
                      varx_index,
                      vary_index,
                      nbinsx, xmin, xmax,
                      nbinsy, ymin, ymax,
                      varx_label=None,
                      vary_label=None,
                      logscale=False,
                      ncontours=20,
                      only_pixels=False,
                      black_contourlines=False,
                      cmap="inferno"):
    """Plot the NN output vs two variables with the other variables set to *means*.

    A regular grid of ``nbinsx`` x ``nbinsy`` points spanning
    [xmin, xmax] x [ymin, ymax] is built. For each point a copy of *means* is
    made with columns *varx_index* / *vary_index* replaced by the grid
    coordinates, and ``scorefun`` is evaluated on all of them at once.

    Parameters
    ----------
    plotname : output file passed to save_show
    means : one value per input variable (template event)
    scorefun : callable evaluated on the (n_points, n_variables) event array;
        its result must be reshapeable to (nbinsy, nbinsx)
    varx_index, vary_index : columns varied along the x / y axis
    logscale : logarithmic colour scale (a non-positive minimum is replaced
        by 1e-5)
    ncontours : number of contour levels
    only_pixels : draw a pcolormesh instead of filled contours
    black_contourlines : overlay black contour lines
    cmap : matplotlib colormap name
    """
    logger.info("Creating varied events (2d)")

    sequence1 = np.linspace(xmin, xmax, nbinsx)
    sequence2 = np.linspace(ymin, ymax, nbinsy)

    template = np.asarray(means)
    if not np.issubdtype(template.dtype, np.floating):
        # an integer template would silently truncate the float grid values
        template = template.astype(float)

    # Grid point (sequence2[i], sequence1[j]) becomes event i * len(sequence1) + j,
    # so the scores can be reshaped straight back into a (y, x) array.
    grid_x, grid_y = np.meshgrid(sequence1, sequence2)
    events = np.tile(template, (grid_x.size, 1))
    events[:, varx_index] = grid_x.ravel()
    events[:, vary_index] = grid_y.ravel()

    logger.info("Predicting scores")
    scores = np.asarray(scorefun(events)).reshape(len(sequence2), len(sequence1))

    fig, ax = plt.subplots()

    zmin = np.min(scores)
    zmax = np.max(scores)

    if logscale:
        if zmin <= 0:
            zmin = 1e-5
            logger.info("Setting zmin to {}".format(zmin))
        lvls = np.logspace(math.log10(zmin), math.log10(zmax), ncontours)
        if only_pixels:
            pcm = ax.pcolormesh(sequence1, sequence2, scores, norm=matplotlib.colors.LogNorm(vmin=zmin, vmax=zmax), cmap=cmap, linewidth=0, rasterized=True)
        else:
            pcm = ax.contourf(sequence1, sequence2, scores, levels=lvls, norm=matplotlib.colors.LogNorm(vmin=zmin, vmax=zmax), cmap=cmap)
        if black_contourlines:
            ax.contour(sequence1, sequence2, scores, levels=lvls, colors="k", linewidths=1)
        l_f = LogFormatter(10, labelOnlyBase=False, minor_thresholds=(np.inf, np.inf))
        cbar = fig.colorbar(pcm, ax=ax, extend='max', ticks=lvls, format=l_f)
    else:
        if only_pixels:
            pcm = ax.pcolormesh(sequence1, sequence2, scores, norm=matplotlib.colors.Normalize(vmin=zmin, vmax=zmax), cmap=cmap, linewidth=0, rasterized=True)
        else:
            pcm = ax.contourf(sequence1, sequence2, scores, ncontours, norm=matplotlib.colors.Normalize(vmin=zmin, vmax=zmax), cmap=cmap)
        if black_contourlines:
            ax.contour(sequence1, sequence2, scores, ncontours, colors="k", linewidths=1)
        cbar = fig.colorbar(pcm, ax=ax, extend='max')

    cbar.set_label("NN output")
    if varx_label is not None:
        ax.set_xlabel(varx_label)
    if vary_label is not None:
        ax.set_ylabel(vary_label)
    save_show(plt, fig, plotname)


def plot_NN_vs_var_2D_all(plotname, model, means,
                          varx_index,
                          vary_index,
                          nbinsx, xmin, xmax,
                          nbinsy, ymin, ymax,
                          transform_function=None,
                          varx_label=None,
                          vary_label=None,
                          zrange=None, logz=False,
                          plot_last_layer=False,
                          log_default_ymin=1e-5,
                          cmap="inferno"):
    """Similar to plot_NN_vs_var_2D, but creates a grid of plots for all neurons.

    The grid events are built as in plot_NN_vs_var_2D, optionally passed through
    *transform_function*, and fed to get_activations. Each neuron's activation
    map is drawn on an ImageGrid with one row per neuron and one column per
    layer.

    Parameters
    ----------
    zrange : optional (vmin, vmax) for the shared colour scale; by default it
        spans the activations of all layers except the last one
    logz : logarithmic colour scale; a non-positive minimum is replaced by
        *log_default_ymin*
    plot_last_layer : also draw the last layer; its panels use their own
        (automatic) colour scale
    cmap : matplotlib colormap name
    """
    varx_vals = np.linspace(xmin, xmax, nbinsx)
    vary_vals = np.linspace(ymin, ymax, nbinsy)

    template = np.asarray(means)
    if not np.issubdtype(template.dtype, np.floating):
        # an integer template would silently truncate the float grid values
        template = template.astype(float)

    # Grid point (vary_vals[i], varx_vals[j]) becomes event i * len(varx_vals) + j.
    grid_x, grid_y = np.meshgrid(varx_vals, vary_vals)
    events = np.tile(template, (grid_x.size, 1))
    events[:, varx_index] = grid_x.ravel()
    events[:, vary_index] = grid_y.ravel()

    # transform
    if transform_function is not None:
        events = transform_function(events)

    acts = get_activations(model, events, print_shape_only=True)

    plotted_acts = acts if plot_last_layer else acts[:-1]
    n_neurons = [len(layer_acts[0]) for layer_acts in plotted_acts]
    layers = len(n_neurons)

    nrows_ncols = (layers, max(n_neurons))
    fig = plt.figure(1, figsize=nrows_ncols)
    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0,
                     label_mode="1",
                     aspect=False,
                     cbar_location="top",
                     cbar_mode="single",
                     cbar_pad=.2,
                     cbar_size="5%",)
    # rows are neurons, columns are layers
    grid_array = np.array(grid).reshape(*nrows_ncols[::-1])

    # the shared colour scale leaves out the last layer
    global_min = min([np.min(ar_layer) for ar_layer in acts[:-1]])
    global_max = max([np.max(ar_layer) for ar_layer in acts[:-1]])

    logger.info("global_min: {}".format(global_min))
    logger.info("global_max: {}".format(global_max))

    if global_min <= 0 and logz:
        global_min = log_default_ymin
        logger.info("Changing global_min to {}".format(log_default_ymin))

    norm_class = matplotlib.colors.LogNorm if logz else matplotlib.colors.Normalize

    im = None
    cbar_im = None  # last image drawn on the shared colour scale
    for layer in range(layers):
        shared_scale = not (plot_last_layer and layer == layers-1)
        for neuron in range(n_neurons[layer]):
            acts_neuron = acts[layer][:,neuron].reshape(len(vary_vals), len(varx_vals))
            ax = grid_array[neuron][layer]
            extra_opts = {}
            if shared_scale:
                # hidden layers share one z-scale
                if zrange is not None:
                    extra_opts["norm"] = norm_class(vmin=zrange[0], vmax=zrange[1])
                else:
                    extra_opts["norm"] = norm_class(vmin=global_min, vmax=global_max)
            # NOTE(review): X/Y have the same shape as C here; newer matplotlib
            # versions reject that with the default flat shading - confirm.
            im = ax.pcolormesh(varx_vals, vary_vals, acts_neuron, cmap=cmap, linewidth=0, rasterized=True, **extra_opts)
            if shared_scale:
                cbar_im = im
            ax.set_facecolor("black")
            if varx_label is not None:
                ax.set_xlabel(varx_label)
            if vary_label is not None:
                ax.set_ylabel(vary_label)
            ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")

    # The single colorbar should describe the shared scale, not the autoscaled
    # last layer (which is drawn last when plot_last_layer is set).
    cb = fig.colorbar(cbar_im if cbar_im is not None else im, cax=grid[0].cax, orientation="horizontal")
    cb.ax.xaxis.set_ticks_position('top')
    cb.ax.xaxis.set_label_position('top')

    save_show(plt, fig, plotname, bbox_inches='tight')


def plot_profile_2D_all(plotname, model, events,
                        valsx, valsy,
                        nbinsx, xmin, xmax,
                        nbinsy, ymin, ymax,
                        transform_function=None,
                        varx_label=None,
                        vary_label=None,
                        zrange=None, logz=False,
                        plot_last_layer=False,
                        log_default_ymin=1e-5,
                        global_norm=True,
                        cmap="inferno", **kwargs):
    """Similar to plot_profile_2D, but creates a grid of plots for all neurons.

    The activations of every neuron for *events* (read with get_activations)
    are profiled in the (valsx, valsy) plane with get_profile_2D. They are drawn
    on an ImageGrid with one row per neuron and one column per layer.

    Parameters
    ----------
    events : model inputs (passed through *transform_function* first, if given)
    valsx, valsy : per-event coordinates used for the 2D binning
    zrange : optional (vmin, vmax) for the shared colour scale; by default it
        spans the profiles of all shared-scale panels
    logz : logarithmic colour scale; a non-positive minimum is replaced by
        *log_default_ymin*
    plot_last_layer : also draw the last layer; its panels use their own
        (automatic) colour scale
    global_norm : if False, every panel is autoscaled individually
    cmap : matplotlib colormap name
    **kwargs : forwarded to get_profile_2D (e.g. ``metric``, ``weights``)
    """
    # transform
    if transform_function is not None:
        events = transform_function(events)

    logger.info("Reading activations for all neurons")
    acts = get_activations(model, events, print_shape_only=True)
    logger.info("Done")

    # one (n_events, n_units) array per plotted layer
    plotted_acts = acts if plot_last_layer else acts[:-1]
    layer_acts = [a.reshape(a.shape[0], -1) for a in plotted_acts]
    n_neurons = [a.shape[1] for a in layer_acts]
    layers = len(n_neurons)

    nrows_ncols = (layers, max(n_neurons))
    fig = plt.figure(1, figsize=nrows_ncols)
    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols[::-1], axes_pad=0,
                     label_mode="1",
                     aspect=False,
                     cbar_location="top",
                     cbar_mode="single",
                     cbar_pad=.2,
                     cbar_size="5%",)
    # rows are neurons, columns are layers
    grid_array = np.array(grid).reshape(*nrows_ncols[::-1])

    norm_class = matplotlib.colors.LogNorm if logz else matplotlib.colors.Normalize

    # value range over all shared-scale panels, collected while profiling
    global_min = None
    global_max = None

    logger.info("Creating profile histograms")
    reg_plots = []
    for layer, neurons_acts in enumerate(layer_acts):
        # hidden layers share one z-scale; the optional last layer is autoscaled
        shared_scale = not (plot_last_layer and layer == layers-1)
        for neuron in range(neurons_acts.shape[1]):
            ax = grid_array[neuron][layer]
            plot_opts = dict(cmap=cmap, linewidth=0, rasterized=True)
            if shared_scale:
                if zrange is not None:
                    plot_opts["norm"] = norm_class(vmin=zrange[0], vmax=zrange[1])
                else:
                    # limits are filled in once all profiles are known
                    plot_opts["norm"] = norm_class()
            hist, xedges, yedges = get_profile_2D(
                valsx, valsy, neurons_acts[:,neuron],
                nbinsx, xmin, xmax,
                nbinsy, ymin, ymax,
                **kwargs
            )
            if shared_scale:
                if global_min is None or hist.min() < global_min:
                    global_min = hist.min()
                if global_max is None or hist.max() > global_max:
                    global_max = hist.max()
            X, Y = np.meshgrid(xedges, yedges)
            reg_plots.append((layer, neuron, ax, shared_scale, (X, Y, hist), plot_opts))
    logger.info("Done")

    logger.info("global_min: {}".format(global_min))
    logger.info("global_max: {}".format(global_max))

    if logz and global_min is not None and global_min <= 0:
        global_min = log_default_ymin
        logger.info("Changing global_min to {}".format(log_default_ymin))

    im = None
    cbar_im = None  # last image drawn on the shared colour scale
    for layer, neuron, ax, shared_scale, mesh_args, plot_opts in reg_plots:
        if not global_norm:
            plot_opts["norm"] = None
        elif zrange is None and "norm" in plot_opts:
            # autoscaled last-layer panels have no norm to update
            plot_opts["norm"].vmin = global_min
            plot_opts["norm"].vmax = global_max
        im = ax.pcolormesh(*mesh_args, **plot_opts)
        if shared_scale:
            cbar_im = im
        ax.set_facecolor("black")
        if varx_label is not None:
            ax.set_xlabel(varx_label)
        if vary_label is not None:
            ax.set_ylabel(vary_label)
        ax.text(0., 0.5, "{}, {}".format(layer, neuron), transform=ax.transAxes, color="white")

    cb = fig.colorbar(cbar_im if cbar_im is not None else im, cax=grid[0].cax, orientation="horizontal")
    cb.ax.xaxis.set_ticks_position('top')
    cb.ax.xaxis.set_label_position('top')

    logger.info("Rendering")
    save_show(plt, fig, plotname, bbox_inches='tight')
    logger.info("Done")


def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=None, log=False, zlabel="# of events"):
    """Draw a precomputed 2D histogram with pcolormesh and save it via save_show.

    Parameters
    ----------
    plotname : output file name
    xedges, yedges : coordinates along x / y, combined with ``np.meshgrid``
    hist : 2D array with rows along y, i.e. ``hist[y_bin, x_bin]``
    varx_label, vary_label : optional axis labels
    log : logarithmic colour scale from the smallest positive to the largest
        entry; falls back to a linear scale if no entry is positive
    zlabel : colour-bar label
    """
    hist = np.asarray(hist)
    X, Y = np.meshgrid(xedges, yedges)

    fig, ax = plt.subplots()

    extraopts = dict()
    if log:
        positive = hist[hist > 0]
        if positive.size > 0:
            extraopts.update(norm=matplotlib.colors.LogNorm(vmin=np.min(positive), vmax=np.max(hist)))
        else:
            # np.min of an empty selection would raise; keep a linear scale
            logger.warning("No positive entries in histogram for {} - using linear colour scale".format(plotname))

    ax.set_facecolor("black")
    pcm = ax.pcolormesh(X, Y, hist, cmap="inferno", linewidth=0, rasterized=True, **extraopts)
    cbar = fig.colorbar(pcm, ax=ax)
    cbar.set_label(zlabel)
    if vary_label is not None:
        ax.set_ylabel(vary_label)
    if varx_label is not None:
        ax.set_xlabel(varx_label)
    save_show(plt, fig, plotname)


def plot_hist_2D_events(plotname, valsx, valsy, nbinsx, xmin, xmax, nbinsy, ymin, ymax,
                        weights=None,
                        varx_label=None, vary_label=None, log=False):
    """Fill a 2D histogram of (valsx, valsy) and plot it with plot_hist_2D.

    The bin edges are ``np.linspace(xmin, xmax, nbinsx)`` and
    ``np.linspace(ymin, ymax, nbinsy)``. *nbinsx*/*nbinsy* therefore count
    edges, which is one more than the number of bins. *weights* are forwarded
    to ``np.histogram2d``.
    """
    edges = (np.linspace(xmin, xmax, nbinsx), np.linspace(ymin, ymax, nbinsy))
    counts, xedges, yedges = np.histogram2d(valsx, valsy, bins=edges, weights=weights)
    # histogram2d indexes counts as [x_bin, y_bin]; plot_hist_2D wants rows along y
    plot_hist_2D(plotname, xedges, yedges, counts.T, varx_label, vary_label, log)


def plot_cond_avg_actmax_2D(plotname, model, layer, neuron, ranges,
                            varx_index, vary_index,
                            nbinsx, xmin, xmax, nbinsy, ymin, ymax,
                            transform=None, inverse_transform=None,
                            ntries=20,
                            step=1,
                            maxit=1,
                            **kwargs):
    """Map a neuron's maximised activation, conditioned on two inputs, over a 2D grid.

    For every grid point (x, y) the function does the following:
    1. Draw a random event with create_random_event(ranges).
    2. If transforms are given, apply inverse_transform, pin inputs
       varx_index / vary_index to x and y, then apply transform. Without
       transforms, the two inputs are pinned directly.
    3. Run max_activation_wrt_input ntries times with those two inputs held
       constant, and store the mean of the first element ([0][0]) of the
       results.

    The resulting map is drawn with plot_hist_2D; extra keyword arguments are
    forwarded to it.

    Raises
    ------
    ValueError
        If only one of transform / inverse_transform is given.
    """
    if (transform is None) != (inverse_transform is None):
        raise ValueError("Need to pass both transform and inverse_transform if data should be transformed")

    xgrid = np.linspace(xmin, xmax, nbinsx)
    ygrid = np.linspace(ymin, ymax, nbinsy)

    act_map = np.zeros((int(nbinsx), int(nbinsy)))

    gradient_function = get_grad_function(model, layer, neuron)

    for ix, xval in enumerate(xgrid):
        for iy, yval in enumerate(ygrid):
            event = create_random_event(ranges)
            if inverse_transform is not None:
                event = inverse_transform(event)
            event[0][varx_index] = xval
            event[0][vary_index] = yval
            if transform is not None:
                event = transform(event)
            # NOTE(review): every try starts from this same event, so averaging
            # over ntries only helps if max_activation_wrt_input is itself
            # stochastic (or mutates its input) - confirm.
            tries = [
                max_activation_wrt_input(gradient_function, event, maxit=maxit, step=step,
                                         const_indices=[varx_index, vary_index])[0][0]
                for _ in range(ntries)
            ]
            act_map[ix, iy] = np.mean(tries)

    # act_map is indexed [x, y]; plot_hist_2D wants rows along y
    plot_hist_2D(plotname, xgrid, ygrid, act_map.T, zlabel="Neuron output", **kwargs)


def get_profile_2D(valsx, valsy, scores,
                   nbinsx, xmin, xmax,
                   nbinsy, ymin, ymax,
                   metric=np.mean, weights=None):
    """Compute a 2D profile histogram: *metric* of *scores* in each (x, y) bin.

    The bin edges are ``np.linspace(xmin, xmax, nbinsx)`` and
    ``np.linspace(ymin, ymax, nbinsy)``. *nbinsx*/*nbinsy* are numbers of edges,
    so there are ``nbinsx - 1`` x ``nbinsy - 1`` bins. Entries are assigned with
    ``np.digitize`` (left edge inclusive, right edge exclusive); entries outside
    [min, max) fall into no bin and are ignored. Empty bins are set to 0.

    Parameters
    ----------
    valsx, valsy : per-entry coordinates
    scores : per-entry values to profile
    metric : callable reducing the scores of one bin to a number. When
        *weights* are given it is called as ``metric(scores, weights=w)``.
        ``np.mean`` has no ``weights`` argument, so in that case it is replaced
        by the weighted mean ``np.average``.
    weights : optional per-entry weights

    Returns
    -------
    (hist, xedges, yedges), where ``hist`` has shape
    ``(len(yedges) - 1, len(xedges) - 1)`` and is indexed as
    ``hist[y_bin, x_bin]``. This is the same layout as
    ``np.histogram2d(...)[0].T``.
    """
    xedges = np.linspace(xmin, xmax, nbinsx)
    yedges = np.linspace(ymin, ymax, nbinsy)

    scores = np.asarray(scores)
    if weights is not None:
        weights = np.asarray(weights)
        if metric is np.mean:
            # np.mean(..., weights=...) raises TypeError
            metric = np.average

    # digitize returns 0 below the first edge and len(edges) at/above the last
    # one; only indices 1 .. len(edges) - 1 are real bins.
    binindices_x = np.digitize(valsx, xedges)
    binindices_y = np.digitize(valsy, yedges)

    hist = np.zeros((len(yedges) - 1, len(xedges) - 1))
    for iy in range(len(yedges) - 1):
        in_row = binindices_y == iy + 1
        for ix in range(len(xedges) - 1):
            in_bin = in_row & (binindices_x == ix + 1)
            if not in_bin.any():
                continue  # empty bins stay 0
            metric_kwargs = dict()
            if weights is not None:
                metric_kwargs["weights"] = weights[in_bin]
            hist[iy, ix] = metric(scores[in_bin], **metric_kwargs)

    return hist, xedges, yedges


def plot_profile_2D(plotname, valsx, valsy, scores,
                    nbinsx, xmin, xmax,
                    nbinsy, ymin, ymax,
                    metric=np.mean,
                    weights=None,
                    **kwargs):
    """Plot a 2D profile histogram of *scores* binned in (valsx, valsy).

    The profile is computed by get_profile_2D with the given binning, *metric*
    and *weights*. The remaining keyword arguments go to plot_hist_2D; the
    colour-bar label defaults to "Profile".
    """
    kwargs.setdefault("zlabel", "Profile")
    binning = (nbinsx, xmin, xmax, nbinsy, ymin, ymax)
    hist, xedges, yedges = get_profile_2D(valsx, valsy, scores, *binning,
                                          metric=metric, weights=weights)
    plot_hist_2D(plotname, xedges, yedges, hist, **kwargs)


if __name__ == "__main__":

    # Ad-hoc driver: builds a ClassificationProject and runs the test_*
    # functions defined below (all of them, or only the one named by
    # sys.argv[2]). The relative imports mean this has to be run as a module
    # (python -m <package>.<module>), not as a plain script.

    import sys

    from .toolkit import ClassificationProject

    # redundant with the module-level import, but harmless
    import logging
    logging.basicConfig()
    logging.getLogger().setLevel(logging.DEBUG)

    from .utils import get_single_neuron_function, get_max_activation_events

    # NOTE(review): meme is only referenced in the commented-out line below;
    # presumably kept for import side effects - confirm before removing.
    import meme
    # meme.setOptions(overrideCache="/scratch-local/nhartmann/meme_cache")

    # project directory: first command-line argument, or a hard-coded default
    if len(sys.argv) < 2:
        c = ClassificationProject("/project/etp/nhartmann/p/keras/021-check-low-vs-high-fewvar/all_high")
    else:
        c = ClassificationProject(sys.argv[1])

    def test_mean_signal():
        """Scan the NN output around the mean signal (class 1) event in met and mt.

        Writes met.pdf, mt.pdf, mt_vs_met.pdf, mt_vs_met_all.pdf and
        mt_vs_met_crosscheck.pdf.
        """

        c._load_data() # untransformed

        # per-variable mean of the class-1 test events
        mean_signal = get_mean_event(c.x_test, c.y_test, 1)

        print("Mean signal: ")
        for branch_index, val in enumerate(mean_signal):
            print("{:>20}: {:<10.3f}".format(c.fields[branch_index], val))

        # 1D scans of met and mt with all other variables at their mean
        plot_NN_vs_var_1D("met.pdf", mean_signal,
                          scorefun=c.evaluate,
                          var_index=c.fields.index("met"),
                          var_range=(0, 1000, 10),
                          var_label="met [GeV]")

        plot_NN_vs_var_1D("mt.pdf", mean_signal,
                          scorefun=c.evaluate,
                          var_index=c.fields.index("mt"),
                          var_range=(0, 500, 10),
                          var_label="mt [GeV]")

        # 2D scan of the NN output in the met/mt plane
        plot_NN_vs_var_2D("mt_vs_met.pdf", means=mean_signal,
                          scorefun=c.evaluate,
                          varx_index=c.fields.index("met"),
                          vary_index=c.fields.index("mt"),
                          nbinsx=100, xmin=0, xmax=1000,
                          nbinsy=100, ymin=0, ymax=500,
                          varx_label="met [GeV]", vary_label="mt [GeV]")


        # same scan for every neuron; the model is fed c.transform'ed events
        plot_NN_vs_var_2D_all("mt_vs_met_all.pdf", means=mean_signal,
                              model=c.model, transform_function=c.transform,
                              varx_index=c.fields.index("met"),
                              vary_index=c.fields.index("mt"),
                              nbinsx=100, xmin=0, xmax=1000,
                              nbinsy=100, ymin=0, ymax=500,
                              varx_label="met [GeV]", vary_label="mt [GeV]")

        # presumably projects with several model inputs provide get_input_list
        # to split the transformed events - TODO confirm
        input_transform = c.transform
        if hasattr(c, "get_input_list"):
            input_transform = lambda x : c.get_input_list(c.transform(x))

        # cross-check: evaluate neuron 0 of layer 3 directly via
        # get_single_neuron_function (presumably the output neuron - confirm)
        plot_NN_vs_var_2D("mt_vs_met_crosscheck.pdf", means=mean_signal,
                          scorefun=get_single_neuron_function(c.model, layer=3, neuron=0, input_transform=input_transform),
                          varx_index=c.fields.index("met"),
                          vary_index=c.fields.index("mt"),
                          nbinsx=100, xmin=0, xmax=1000,
                          nbinsy=100, ymin=0, ymax=500,
                          varx_label="met [GeV]", vary_label="mt [GeV]")


    def test_max_act():
        """Histogram events found by activation maximisation of neuron 0 in layer 3.

        The events returned by get_max_activation_events (threshold=0.2) are
        inverse-transformed and histogrammed as met vs mt and as mt vs the
        returned ``losses`` (labelled "NN output").
        """

        # transformed events
        c.load(reload=True)
        # per-variable 1st-99th percentile of the (transformed) test data,
        # passed to get_max_activation_events as ranges
        ranges = [np.percentile(c.x_test[:,var_index], [1,99]) for var_index in range(len(c.fields))]

        losses, events = get_max_activation_events(c.model, ranges, ntries=100000, layer=3, neuron=0, threshold=0.2)

        # back to physical units for plotting
        events = c.inverse_transform(events)

        plot_hist_2D_events(
            "mt_vs_met_actmaxhist.pdf",
            events[:,c.fields.index("met")],
            events[:,c.fields.index("mt")],
            100, 0, 1000,
            100, 0, 500,
            varx_label="met [GeV]", vary_label="mt [GeV]",
        )

        plot_hist_2D_events(
            "mt_vs_output_actmax.pdf",
            events[:,c.fields.index("mt")],
            losses,
            100, 0, 500,
            100, 0, 1,
            varx_label="mt [GeV]", vary_label="NN output",
            log=True,
        )


    def test_cond_max_act():
        """Conditional activation-maximisation map of neuron 0 in layer 3 over met/mt (30x30 grid)."""

        c.load(reload=True)
        # per-variable 1st-99th percentile of the (transformed) test data
        ranges = [np.percentile(c.x_test[:,var_index], [1,99]) for var_index in range(len(c.fields))]

        plot_cond_avg_actmax_2D(
            "mt_vs_met_cond_actmax.pdf",
            c.model, 3, 0, ranges,
            c.fields.index("met"),
            c.fields.index("mt"),
            30, 0, 1000,
            30, 0, 500,
            transform=c.transform, inverse_transform=c.inverse_transform,
            varx_label="met [GeV]", vary_label="mt [GeV]",
        )


    def test_xtest_vs_output():
        """2D histograms of the inverse-transformed test data.

        Plots mt vs NN output for signal, and met vs mt for signal (y_test == 1)
        and for background (y_test == 0).
        """

        c.load(reload=True)

        # test inputs in physical units
        utrf_x_test = c.inverse_transform(c.x_test)

        plot_hist_2D_events(
            "mt_vs_output_signal_test.pdf",
            utrf_x_test[c.y_test==1][:,c.fields.index("mt")],
            c.scores_test[c.y_test==1].reshape(-1),
            100, 0, 1000,
            100, 0, 1,
            varx_label="mt [GeV]", vary_label="NN output",
            log=True,
        )

        plot_hist_2D_events(
            "mt_vs_met_signal.pdf",
            utrf_x_test[c.y_test==1][:,c.fields.index("met")],
            utrf_x_test[c.y_test==1][:,c.fields.index("mt")],
            100, 0, 1000,
            100, 0, 500,
            varx_label="met [GeV]",
            vary_label="mt [GeV]",
            log=True,
        )

        # NOTE: the output file name is spelled "backgound"
        plot_hist_2D_events(
            "mt_vs_met_backgound.pdf",
            utrf_x_test[c.y_test==0][:,c.fields.index("met")],
            utrf_x_test[c.y_test==0][:,c.fields.index("mt")],
            100, 0, 1000,
            100, 0, 500,
            varx_label="met [GeV]",
            vary_label="mt [GeV]",
            log=True,
        )


        # Disabled: `events` and `losses` are not defined in this function
        # (they come from test_max_act).
        # plot_hist_2D_events(
        #     "apl_vs_output_actmax.pdf",
        #     events[:,c.fields.index("LepAplanarity")],
        #     losses,
        #     100, 0, 0.1,
        #     100, 0, 1,
        #     varx_label="Aplanarity", vary_label="NN output",
        # )


    def test_profile():
        """Profile plots (mean and max NN output) of signal test events in the met/mt plane."""

        c.load(reload=True)
        utrf_x_test = c.inverse_transform(c.x_test)

        # x axis is met (0-1000 GeV), y axis is mt (0-500 GeV). These match the
        # ranges used for these variables in the other test_* functions; the
        # two ranges were previously swapped.
        plot_profile_2D(
            "mt_vs_met_profilemean_sig.pdf",
            utrf_x_test[c.y_test==1][:,c.fields.index("met")],
            utrf_x_test[c.y_test==1][:,c.fields.index("mt")],
            c.scores_test[c.y_test==1].reshape(-1),
            20, 0, 1000,
            20, 0, 500,
            varx_label="met [GeV]", vary_label="mt [GeV]",
        )

        plot_profile_2D(
            "mt_vs_met_profilemax_sig.pdf",
            utrf_x_test[c.y_test==1][:,c.fields.index("met")],
            utrf_x_test[c.y_test==1][:,c.fields.index("mt")],
            c.scores_test[c.y_test==1].reshape(-1),
            20, 0, 1000,
            20, 0, 500,
            metric=np.max,
            varx_label="met [GeV]", vary_label="mt [GeV]",
        )

    # Run every test_* function, or only the one named by the second CLI argument.
    requested = sys.argv[2] if len(sys.argv) > 2 else None
    for name in dir():
        if not (name.startswith("test_") and callable(locals()[name])):
            continue
        if requested is not None and name != requested:
            continue
        print("Running {}".format(name))
        locals()[name]()