Skip to content
Snippets Groups Projects
plotting.py 4.34 KiB
Newer Older
#!/usr/bin/env python

import os
import math

import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib.ticker import LogFormatter
import numpy as np

import meme

"""
Some further plotting functions
"""

def get_mean_event(x, y, class_label):
    """Return the per-variable mean of all events that carry a given class label.

    Parameters
    ----------
    x : array-like, 2D
        Events, indexed as ``x[event, variable]``.
    y : array-like, 1D
        Label of each event; compared element-wise with ``class_label``.
    class_label
        Label selecting the events that enter the mean.

    Returns
    -------
    list
        One mean value per column of ``x``. If no event matches, the
        entries are NaN (numpy emits a RuntimeWarning in that case).
    """
    x = np.asarray(x)
    # Build the mask and select the rows once instead of once per column.
    selected = x[np.asarray(y) == class_label]
    return [np.mean(selected[:, var_index]) for var_index in range(x.shape[1])]


def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None):
    """Plot the NN output vs one variable with the other variables set to the given mean values.

    Parameters
    ----------
    plotname
        Output target handed to ``fig.savefig``.
    means : sequence of numbers
        One value per input variable; used for every variable except the
        scanned one.
    scorefun : callable
        Called once with a 2D array holding one event per row; its return
        value is plotted against the scanned values.
    var_index : int
        Position (within ``means``) of the variable that is scanned.
    var_range : tuple
        Arguments forwarded to ``np.arange`` (e.g. ``(start, stop, step)``).
    var_label : str, optional
        x-axis label; no label is set when ``None``.
    """
    print("Creating varied events (1d)")
    sequence = np.arange(*var_range)

    base = np.asarray(means)
    if not np.issubdtype(base.dtype, np.inexact):
        # With integer-typed means the event array would be integer as well
        # and fractional scan values would be silently truncated below.
        base = base.astype(float)

    # one copy of the mean event per scan point, then overwrite the scanned column
    events = np.tile(base, (len(sequence), 1))
    events[:, var_index] = sequence

    print("Predicting scores")
    scores = scorefun(events)

    fig, ax = plt.subplots()
    try:
        ax.plot(sequence, scores)
        if var_label is not None:
            ax.set_xlabel(var_label)
        ax.set_ylabel("NN output")
        fig.savefig(plotname)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # the figure is not returned, so release it here.
        plt.close(fig)


def plot_NN_vs_var_2D(plotname, means,
                      scorefun,
                      var1_index, var1_range,
                      var2_index, var2_range,
                      var1_label=None,
                      var2_label=None,
                      logscale=False,
                      ncontours=20,
                      black_contourlines=False):
    """Plot the NN output vs two variables with the other variables set to the given mean values.

    The output is drawn as a filled contour plot with variable 1 on the
    x-axis and variable 2 on the y-axis.

    Parameters
    ----------
    plotname
        Output target handed to ``fig.savefig``.
    means : sequence of numbers
        One value per input variable; used for every variable except the
        two scanned ones.
    scorefun : callable
        Called once with a 2D array holding one event per row; must return
        one score per event (the result is reshaped onto the scan grid).
    var1_index, var2_index : int
        Positions (within ``means``) of the scanned variables.
    var1_range, var2_range : tuple
        Arguments forwarded to ``np.arange`` (e.g. ``(start, stop, step)``).
    var1_label, var2_label : str, optional
        Axis labels; an axis stays unlabelled when its label is ``None``.
    logscale : bool
        If true, use ``ncontours`` log-spaced levels between the smallest
        and largest score together with a logarithmic color normalisation.
        Otherwise colors are normalised to the fixed interval [0, 1].
    ncontours : int
        Number of contour levels.
    black_contourlines : bool
        If true, additionally draw black contour lines at the levels.

    Raises
    ------
    ValueError
        If ``logscale`` is set and the smallest score is not positive.
    """
    print("Creating varied events (2d)")
    # example: vary var1 vs var2
    sequence1 = np.arange(*var1_range)
    sequence2 = np.arange(*var2_range)

    base = np.asarray(means)
    if not np.issubdtype(base.dtype, np.inexact):
        # With integer-typed means the event array would be integer as well
        # and fractional scan values would be silently truncated below.
        base = base.astype(float)

    # 2d grid of events (so effectively 3D): events[i, j] starts as a copy
    # of the mean event, i runs over sequence2 and j over sequence1
    events = np.tile(base, (len(sequence2), len(sequence1), 1))

    # fill in the varied values via broadcasting; var1 is written first and
    # var2 second, so var2 wins should both indices coincide
    events[:, :, var1_index] = sequence1
    events[:, :, var2_index] = sequence2[:, np.newaxis]

    # convert back into 1d array of events
    events = events.reshape(-1, len(base))

    print("Predicting scores")
    # convert scores into 2d array matching the scan grid
    scores = np.asarray(scorefun(events)).reshape(len(sequence2), len(sequence1))

    zmin = np.min(scores)
    zmax = np.max(scores)

    if logscale and zmin <= 0:
        # math.log10 would fail with an opaque "math domain error"
        raise ValueError(
            "logscale=True requires strictly positive scores, "
            "but the minimum score is {}".format(zmin)
        )

    fig, ax = plt.subplots()
    try:
        if logscale:
            lvls = np.logspace(math.log10(zmin), math.log10(zmax), ncontours)
            pcm = ax.contourf(sequence1, sequence2, scores, levels=lvls, norm=matplotlib.colors.LogNorm(vmin=zmin, vmax=zmax))
            if black_contourlines:
                ax.contour(sequence1, sequence2, scores, levels=lvls, colors="k", linewidths=1)
            l_f = LogFormatter(10, labelOnlyBase=False, minor_thresholds=(np.inf, np.inf))
            cbar = fig.colorbar(pcm, ax=ax, extend='max', ticks=lvls, format=l_f)
        else:
            pcm = ax.contourf(sequence1, sequence2, scores, ncontours, norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
            if black_contourlines:
                ax.contour(sequence1, sequence2, scores, ncontours, colors="k", linewidths=1)
            cbar = fig.colorbar(pcm, ax=ax, extend='max')

        cbar.set_label("NN output")
        if var1_label is not None:
            ax.set_xlabel(var1_label)
        if var2_label is not None:
            ax.set_ylabel(var2_label)
        fig.savefig(plotname)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # the figure is not returned, so release it here.
        plt.close(fig)



if __name__ == "__main__":
    # Example run: load a classification project from a hard-coded path,
    # print the mean event of class 1 (labelled "signal" below) and scan
    # the NN output around it in "met" and "mt".

    from .toolkit import ClassificationProject

    project_path = os.path.expanduser("~/p/scripts/keras/008-allhighlevel/all_highlevel_985")
    c = ClassificationProject(project_path)

    mean_signal = get_mean_event(c.x_test, c.y_test, 1)

    print("Mean signal: ")
    for position, mean_value in enumerate(mean_signal):
        branch_name = c.branches[position]
        print("{:>20}: {:<10.3f}".format(branch_name, mean_value))

    # (output file, branch name, np.arange arguments, axis label)
    scans_1d = (
        ("met.pdf", "met", (0, 1000, 10), "met [GeV]"),
        ("mt.pdf", "mt", (0, 500, 10), "mt [GeV]"),
    )
    for outfile, branch, scan_range, axis_label in scans_1d:
        plot_NN_vs_var_1D(
            outfile,
            mean_signal,
            scorefun=c.evaluate,
            var_index=c.branches.index(branch),
            var_range=scan_range,
            var_label=axis_label,
        )

    # combined scan: met on the x-axis, mt on the y-axis
    plot_NN_vs_var_2D(
        "mt_vs_met.pdf",
        means=mean_signal,
        scorefun=c.evaluate,
        var1_index=c.branches.index("met"),
        var1_range=(0, 1000, 10),
        var2_index=c.branches.index("mt"),
        var2_range=(0, 500, 10),
        var1_label="met [GeV]",
        var2_label="mt [GeV]",
    )