#!/usr/bin/env python
"""
Write new TTrees with the signal parameters as branches. For the
backgrounds, the parameters are randomly generated following the total
distribution over all signals. The discrete values are counted for the
whole n-tuple of signal parameters (not for each parameter separately),
such that correlations between signal parameters are taken into account.
"""

import argparse, re, os

import ROOT

from root_numpy import list_trees
from root_pandas import read_root
import numpy as np

def build_sampling_table(count_dict):
    """
    Turn event counts per parameter point into a table to sample from.

    count_dict maps a tuple of parameter values to a number of events.

    Returns (numbers, prob_bins): numbers is a 2D float array with one row
    per parameter point, prob_bins holds the cumulative probability (upper
    bin edge) belonging to each row, in the same order.

    Raises ValueError if there is nothing to build a distribution from.
    """
    if not count_dict:
        raise ValueError("No signal parameter points found - cannot build a distribution")
    # Materialize the keys once: a dict view is not a sequence for np.array
    # on Python 3, and a single explicit list keeps rows and counts aligned.
    points = list(count_dict)
    numbers = np.array(points, dtype=float)
    counts = np.array([count_dict[point] for point in points], dtype=float)
    total = counts.sum()
    if total <= 0:
        raise ValueError("Signal trees contain no entries - cannot build a distribution")
    prob_bins = np.cumsum(counts / total)
    # The cumulative sum may end slightly below 1 due to rounding; pin the
    # last edge so every random number in [0, 1) falls into a valid bin.
    prob_bins[-1] = 1.0
    return numbers, prob_bins


def lookup_sampled_params(rnd, numbers, prob_bins):
    """
    Map uniform random numbers to rows of parameter values.

    rnd holds values in [0, 1); numbers and prob_bins are the arrays
    returned by build_sampling_table. Returns an array with one row of
    numbers per entry of rnd.
    """
    rnd_idx = np.digitize(np.asarray(rnd, dtype=float), prob_bins)
    # np.digitize returns len(prob_bins) for values at or above the last
    # edge - clip so that the lookup below can never go out of range.
    rnd_idx = np.minimum(rnd_idx, len(prob_bins) - 1)
    return numbers[rnd_idx]


if __name__ == "__main__":

    input_filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
    output_filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys_parametrized.root"

    # names of the branches that will hold the signal parameters, in the
    # order of the groups captured by param_match
    param_names = ["mg", "mc", "mn"]

    param_match = r"GG_oneStep_(.*?)_(.*?)_(.*?)_NoSys"
    param_regex = re.compile(param_match)

    # all signal trees are merged into one tree with this name
    output_signal_treename = "GG_oneStep_NoSys"

    bkg_trees = [
        "diboson_Sherpa221_NoSys",
        "singletop_NoSys",
        "ttbar_NoSys",
        "ttv_NoSys",
        "wjets_Sherpa221_NoSys",
        "zjets_Sherpa221_NoSys",
    ]

    # read in the number of events for each combination of parameters
    f = ROOT.TFile.Open(input_filename)
    if not f or f.IsZombie():
        raise IOError("Could not open input file {}".format(input_filename))
    count_dict = {}
    seen_names = set()
    try:
        for key in f.GetListOfKeys():
            tree_name = key.GetName()
            # The key list has one entry per stored cycle, but f.Get(name)
            # always returns the same object - count each name only once.
            if tree_name in seen_names:
                continue
            seen_names.add(tree_name)
            match = param_regex.match(tree_name)
            if match is None:
                continue
            params = tuple(float(i) for i in match.groups())
            # TODO: might be better to use sum of weights
            count_dict[params] = count_dict.get(params, 0) + f.Get(tree_name).GetEntries()
    finally:
        # close the file even if reading one of the trees fails
        f.Close()

    # calculate cumulative sum of counts to sample signal parameters for background from
    numbers, prob_bins = build_sampling_table(count_dict)

    # read and write the rest in chunks
    if os.path.exists(output_filename):
        # to_root below appends, so start from a clean output file
        os.remove(output_filename)
    for tree_name in list_trees(input_filename):
        match_signal = param_regex.match(tree_name)
        is_signal = match_signal is not None
        if not is_signal and tree_name not in bkg_trees:
            continue
        print("Writing {}".format(tree_name))
        # backgrounds keep their tree name, signals go into the merged tree
        out_tree_name = output_signal_treename if is_signal else tree_name
        nwritten = 0
        for df in read_root(input_filename, tree_name, chunksize=100000):
            print("Writing event {}".format(nwritten))
            if is_signal:
                # signal: the parameters are fixed by the tree name
                for param_name, param_value in zip(param_names, match_signal.groups()):
                    df[param_name] = float(param_value)
                df["training_weight"] = df["eventWeight"]
            else:
                # background: draw one parameter point per event from the
                # distribution of parameter points over all signal events
                rnd = np.random.random(len(df))
                param_values = lookup_sampled_params(rnd, numbers, prob_bins)
                for param_idx, param_name in enumerate(param_names):
                    df[param_name] = param_values[:, param_idx]
                df["training_weight"] = df["eventWeight"]*df["genWeight"]
            df.to_root(output_filename, mode="a", key=out_tree_name)
            nwritten += len(df)