From 1b0543801a01a8281e49e0887e89f8b2605f952f Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Fri, 3 Aug 2018 15:05:52 +0200
Subject: [PATCH] hist mode for plot_single_neuron

---
 plotting.py                   |  6 ++++--
 scripts/plot_single_neuron.py | 38 ++++++++++++++++++++++++++++-------
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/plotting.py b/plotting.py
index d3b306a..6e83a8f 100644
--- a/plotting.py
+++ b/plotting.py
@@ -236,12 +236,14 @@ def plot_hist_2D(plotname, xedges, yedges, hist, varx_label=None, vary_label=Non
     plt.close(fig)
 
 
-def plot_hist_2D_events(plotname, valsx, valsy, nbinsx, xmin, xmax, nbinsy, ymin, ymax, varx_label=None, vary_label=None, log=False):
+def plot_hist_2D_events(plotname, valsx, valsy, nbinsx, xmin, xmax, nbinsy, ymin, ymax,
+                        weights=None,
+                        varx_label=None, vary_label=None, log=False):
 
     xedges = np.linspace(xmin, xmax, nbinsx)
     yedges = np.linspace(ymin, ymax, nbinsy)
 
-    hist, xedges, yedges = np.histogram2d(valsx, valsy, bins=(xedges, yedges))
+    hist, xedges, yedges = np.histogram2d(valsx, valsy, bins=(xedges, yedges), weights=weights)
 
     hist = hist.T
 
diff --git a/scripts/plot_single_neuron.py b/scripts/plot_single_neuron.py
index a2ce875..bb638d8 100755
--- a/scripts/plot_single_neuron.py
+++ b/scripts/plot_single_neuron.py
@@ -9,7 +9,8 @@ from KerasROOTClassification import ClassificationProject
 from KerasROOTClassification.plotting import (
     get_mean_event,
     plot_NN_vs_var_2D,
-    plot_profile_2D
+    plot_profile_2D,
+    plot_hist_2D_events
 )
 from KerasROOTClassification.tfhelpers import get_single_neuron_function
 
@@ -18,7 +19,9 @@ parser.add_argument("project_dir")
 parser.add_argument("output_filename")
 parser.add_argument("varx")
 parser.add_argument("vary")
-parser.add_argument("-m", "--mode", choices=["mean_sig", "mean_bkg", "profile_sig", "profile_bkg"], default="mean_sig")
+parser.add_argument("-m", "--mode",
+                    choices=["mean_sig", "mean_bkg", "profile_sig", "profile_bkg", "hist_sig", "hist_bkg"],
+                    default="mean_sig")
 parser.add_argument("-l", "--layer", type=int, help="Layer index (takes last layer by default)")
 parser.add_argument("-n", "--neuron", type=int, default=0, help="Neuron index (takes first neuron by default)")
 parser.add_argument("--log", action="store_true", help="Plot in color in log scale")
@@ -41,10 +44,11 @@ if layer is None:
 varx_index = c.branches.index(args.varx)
 vary_index = c.branches.index(args.vary)
 
-x_test = c.x_test
+varx_label = args.varx
+vary_label = args.vary
 
-percentilesx = np.percentile(x_test[:,varx_index], [1,99])
-percentilesy = np.percentile(x_test[:,vary_index], [1,99])
+percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
+percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
 
 if args.xrange is not None:
     if len(args.xrange) < 3:
@@ -75,7 +79,7 @@ if args.mode.startswith("mean"):
         scorefun=get_single_neuron_function(c.model, layer, neuron, scaler=c.scaler),
         xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
         ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
-        varx_label=args.varx, vary_label=args.vary,
+        varx_label=varx_label, vary_label=vary_label,
         logscale=args.log, only_pixels=(not args.contour)
     )
 
@@ -106,6 +110,26 @@ elif args.mode.startswith("profile"):
         xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
         ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
         metric=metric_dict[args.profile_metric],
-        varx_label="met [GeV]", vary_label="mt [GeV]",
+        varx_label=varx_label, vary_label=vary_label,
         **opt_kwargs
     )
+
+elif args.mode.startswith("hist"):
+
+    if args.mode == "hist_sig":
+        class_index = 1
+    else:
+        class_index = 0
+
+    valsx = c.x_test[c.y_test==class_index][:,varx_index]
+    valsy = c.x_test[c.y_test==class_index][:,vary_index]
+    weights = c.w_test[c.y_test==class_index]
+
+    plot_hist_2D_events(
+        args.output_filename,
+        valsx, valsy,
+        xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
+        ymin=vary_range[0], ymax=vary_range[1], nbinsy=vary_range[2],
+        weights=weights,
+        varx_label=varx_label, vary_label=vary_label,
+    )
-- 
GitLab