#!/usr/bin/env python

import os
import json

import logging
# Logger used throughout this module; it is named explicitly ("KerasROOTClassification")
# rather than via __name__, so configure that name to see its messages.
logger = logging.getLogger("KerasROOTClassification")

from root_numpy import tree2array, rec2array
import numpy as np
import pandas as pd
import h5py
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from sklearn.metrics import roc_curve

from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
import matplotlib.pyplot as plt

# configure number of cores
# Keras only honours these settings if the session is registered with it via
# K.set_session - otherwise it silently creates its own default session.
import tensorflow as tf
from keras import backend as K
num_cores = 1
config = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
                        inter_op_parallelism_threads=num_cores,
                        device_count={'CPU': num_cores})
session = tf.Session(config=config)
K.set_session(session)

import ROOT

class KerasROOTClassification:
    """Train a Keras MLP to separate signal from background events stored in ROOT trees.

    Tree entries are split alternately into a training sample (even entries)
    and a test sample (odd entries). Datasets, the fitted scaler, the network
    weights, the model description and plots are cached in
    ``<out_dir>/<name>`` so that training can be resumed across runs.
    """

    # arrays cached as one hdf5 file each (see _dump_to_hdf5/_load_from_hdf5)
    dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"]

    def __init__(self, name,
                 signal_trees, bkg_trees, branches, weight_expr, identifiers,
                 layers=3, nodes=64, batch_size=128, activation_function='relu', out_dir="./outputs"):
        """
        :param name: project name; all outputs are stored in ``out_dir/name``
        :param signal_trees: list of (filename, treename) tuples for signal
        :param bkg_trees: list of (filename, treename) tuples for background
        :param branches: input variables (branch names or expressions)
        :param weight_expr: expression evaluated per event as event weight
        :param identifiers: branches identifying an event; only used to record
                            which events went into the training sample
        :param layers: number of hidden layers
        :param nodes: number of nodes per hidden layer
        :param batch_size: training batch size
        :param activation_function: activation of the hidden layers
        :param out_dir: base directory holding all projects
        """
        self.name = name
        self.signal_trees = signal_trees
        self.bkg_trees = bkg_trees
        self.branches = branches
        self.weight_expr = weight_expr
        self.identifiers = identifiers
        self.layers = layers
        self.nodes = nodes
        self.batch_size = batch_size
        self.activation_function = activation_function
        self.out_dir = out_dir

        self.project_dir = os.path.join(self.out_dir, name)

        # makedirs also creates out_dir if it does not exist yet
        if not os.path.exists(self.project_dir):
            os.makedirs(self.project_dir)

        self.s_train = None
        self.b_train = None
        self.s_test = None
        self.b_test = None
        self.x_train = None
        self.x_test = None
        self.y_train = None
        self.y_test = None
        self.w_train = None
        self.w_test = None

        self.s_eventlist_train = None
        self.b_eventlist_train = None

        self._scaler = None
        self._class_weight = None
        self._bkg_weights = None
        self._sig_weights = None
        self._model = None

        # network outputs, filled by train()
        self.scores_train = None
        self.scores_test = None

        # track the number of epochs this model has been trained
        self.total_epochs = 0

        self.data_loaded = False
        self.data_transformed = False

    def _load_data(self):
        """Load the datasets, preferably from the hdf5 cache in project_dir.

        If the cache is missing or incomplete, the ROOT trees are read, split
        into training (even entries) and test (odd entries) samples, and the
        resulting arrays are written to hdf5 for the next run.
        """
        try:
            self._load_from_hdf5()
        # KeyError: file exists but a dataset is missing;
        # IOError/OSError: cache file does not exist or cannot be opened
        except (KeyError, IOError):

            logger.info("Couldn't load all datasets - reading from ROOT trees")

            # Read signal and background trees into structured numpy arrays
            signal_chain = ROOT.TChain()
            bkg_chain = ROOT.TChain()
            for filename, treename in self.signal_trees:
                signal_chain.AddFile(filename, -1, treename)
            for filename, treename in self.bkg_trees:
                bkg_chain.AddFile(filename, -1, treename)

            data_branches = self.branches + [self.weight_expr]
            train_branches = data_branches + self.identifiers

            # even entries are used for training, odd entries for testing
            self.s_train = tree2array(signal_chain, branches=train_branches, start=0, step=2)
            self.b_train = tree2array(bkg_chain, branches=train_branches, start=0, step=2)
            self.s_test = tree2array(signal_chain, branches=data_branches, start=1, step=2)
            self.b_test = tree2array(bkg_chain, branches=data_branches, start=1, step=2)

            self.s_eventlist_train = self.s_train[self.identifiers]
            self.b_eventlist_train = self.b_train[self.identifiers]

            # now we don't need the identifiers anymore
            self.s_train = self.s_train[data_branches]
            self.b_train = self.b_train[data_branches]

            # create x (input), y (target) and w (weights) arrays
            # the first block will be signals, the second block backgrounds
            self.x_train = np.concatenate((rec2array(self.s_train[self.branches]),
                                           rec2array(self.b_train[self.branches])))
            self.x_test = np.concatenate((rec2array(self.s_test[self.branches]),
                                          rec2array(self.b_test[self.branches])))
            self.w_train = np.concatenate((self.s_train[self.weight_expr],
                                           self.b_train[self.weight_expr]))
            self.w_test = np.concatenate((self.s_test[self.weight_expr],
                                          self.b_test[self.weight_expr]))

            # targets: 1 for signal, 0 for background
            self.y_train = np.concatenate((np.ones(len(self.s_train)),
                                           np.zeros(len(self.b_train))))
            self.y_test = np.concatenate((np.ones(len(self.s_test)),
                                          np.zeros(len(self.b_test))))

            logger.info("Writing to hdf5")
            self._dump_to_hdf5()

        self.data_loaded = True

    def _dump_training_list(self):
        """Write the identifiers of the signal/background training events to CSV files."""
        s_eventlist = pd.DataFrame(self.s_eventlist_train)
        b_eventlist = pd.DataFrame(self.b_eventlist_train)

        s_eventlist.to_csv(os.path.join(self.project_dir, "s_eventlist_train.csv"))
        b_eventlist.to_csv(os.path.join(self.project_dir, "b_eventlist_train.csv"))

    def _dump_to_hdf5(self):
        """Write each array in dataset_names to ``<project_dir>/<name>.h5``."""
        for dataset_name in self.dataset_names:
            with h5py.File(os.path.join(self.project_dir, dataset_name+".h5"), "w") as hf:
                hf.create_dataset(dataset_name, data=getattr(self, dataset_name))

    def _load_from_hdf5(self):
        """Read each array in dataset_names back from its hdf5 file.

        :raises IOError/OSError: if a file cannot be opened
        :raises KeyError: if a file lacks the expected dataset
        """
        for dataset_name in self.dataset_names:
            filename = os.path.join(self.project_dir, dataset_name+".h5")
            logger.info("Trying to load {} from {}".format(dataset_name, filename))
            # read-only: never create empty cache files as a side effect
            with h5py.File(filename, "r") as hf:
                setattr(self, dataset_name, hf[dataset_name][:])
        logger.info("Data loaded")

    @property
    def scaler(self):
        """StandardScaler for the input variables, created on first access.

        Loaded from ``scaler.pkl`` in project_dir if present, otherwise fitted
        to the (untransformed) training inputs and saved there.
        """
        if self._scaler is None:
            filename = os.path.join(self.project_dir, "scaler.pkl")
            try:
                self._scaler = joblib.load(filename)
                logger.info("Loaded existing StandardScaler from {}".format(filename))
            except IOError:
                logger.info("Creating new StandardScaler")
                self._scaler = StandardScaler()
                logger.info("Fitting StandardScaler to training data")
                # fit to the training data only; fitting again to the test data
                # would overwrite these parameters with test-sample statistics
                self._scaler.fit(self.x_train)
                joblib.dump(self._scaler, filename)
        return self._scaler

    def _transform_data(self):
        """Standardize x_train and x_test in place (only once)."""
        if not self.data_transformed:
            # todo: what to do about the outliers? Where do they come from?
            logger.debug("training data before transformation: {}".format(self.x_train))
            logger.debug("minimum values: {}".format(np.min(self.x_train, axis=0)))
            logger.debug("maximum values: {}".format(np.max(self.x_train, axis=0)))
            self.x_train = self.scaler.transform(self.x_train)
            logger.debug("training data after transformation: {}".format(self.x_train))
            self.x_test = self.scaler.transform(self.x_test)
            self.data_transformed = True

    def _info_filename(self):
        """Path of the json file holding persistent bookkeeping (e.g. epochs)."""
        return os.path.join(self.project_dir, "info.json")

    def _load_info(self):
        """Return the contents of info.json, or an empty dict if it does not exist."""
        filename = self._info_filename()
        if not os.path.exists(filename):
            return {}
        with open(filename) as f:
            return json.load(f)

    def _read_info(self, key, default):
        """Return the value stored under ``key`` in info.json, or ``default``."""
        return self._load_info().get(key, default)

    def _write_info(self, key, value):
        """Store ``key: value`` in info.json, keeping all other entries."""
        info = self._load_info()
        info[key] = value
        with open(self._info_filename(), "w") as of:
            json.dump(info, of)

    @property
    def model(self):
        """Simple MLP, built and compiled on first access.

        ``layers`` hidden Dense layers with ``nodes`` nodes each, followed by a
        single sigmoid output node (binary classification). The architecture
        is written to ``model.json`` in project_dir for documentation.
        """
        if self._model is None:

            self._model = Sequential()

            # first hidden layer
            self._model.add(Dense(self.nodes, input_dim=len(self.branches), activation=self.activation_function))
            # the other hidden layers
            for _ in range(self.layers - 1):
                self._model.add(Dense(self.nodes, activation=self.activation_function))
            # last layer is one neuron (binary classification)
            self._model.add(Dense(1, activation='sigmoid'))

            logger.info("Compile model")
            # NOTE(review): optimizer choice is an assumption (SGD) - tune as needed
            self._model.compile(optimizer='SGD',
                                loss='binary_crossentropy',
                                metrics=['accuracy'])

            # dump to json for documentation
            with open(os.path.join(self.project_dir, "model.json"), "w") as of:
                of.write(self._model.to_json())

        return self._model

    @property
    def class_weight(self):
        """Class weights ``{0: background, 1: signal}`` balancing the summed event weights.

        Each class is scaled so that it contributes half of the total sum of
        training event weights. A dict is returned because Keras' ``fit``
        expects ``class_weight`` to map class indices to weights.
        """
        if self._class_weight is None:
            sumw_bkg = np.sum(self.w_train[self.y_train == 0])
            sumw_sig = np.sum(self.w_train[self.y_train == 1])
            sumw = sumw_sig + sumw_bkg
            self._class_weight = {0: sumw / (2 * sumw_bkg),
                                  1: sumw / (2 * sumw_sig)}
        return self._class_weight

    def train(self, epochs=10):
        """Train the model for ``epochs`` further epochs.

        Loads and transforms the data if needed, plots the input variables,
        resumes from ``weights.h5`` if it exists, saves the weights afterwards
        and computes the network scores for the training and test samples.

        :param epochs: number of epochs to train in this call
        """
        if not self.data_loaded:
            self._load_data()

        if not self.data_transformed:
            self._transform_data()

        for branch_index, _ in enumerate(self.branches):
            self.plot_input(branch_index)

        weights_filename = os.path.join(self.project_dir, "weights.h5")
        # build the model outside the try block so that errors while building
        # it are not mistaken for "no weights found"
        model = self.model
        try:
            model.load_weights(weights_filename)
            logger.info("Weights found and loaded")
            logger.info("Continue training")
        except IOError:
            logger.info("No weights found, starting completely new training")

        self.total_epochs = self._read_info("epochs", 0)

        logger.info("Train model")
        model.fit(self.x_train,
                  # the reshape might be unnecessary here
                  self.y_train.reshape(-1, 1),
                  epochs=epochs,
                  batch_size=self.batch_size,
                  class_weight=self.class_weight,
                  shuffle=True)

        logger.info("Save weights")
        model.save_weights(weights_filename)

        self.total_epochs += epochs
        self._write_info("epochs", self.total_epochs)

        logger.info("Create scores for ROC curve")
        self.scores_test = model.predict(self.x_test)
        self.scores_train = model.predict(self.x_train)

    def evaluate(self):
        """Not implemented yet."""
        raise NotImplementedError("evaluate is not implemented yet")

    def write_friend_tree(self):
        """Not implemented yet."""
        raise NotImplementedError("write_friend_tree is not implemented yet")

    @property
    def bkg_weights(self):
        """Class weight times event weight for each background training event (for plotting).

        TODO: find a better way to do this
        """
        if self._bkg_weights is None:
            logger.debug("Calculating background weights for plotting")
            self._bkg_weights = self.class_weight[0] * self.w_train[self.y_train == 0]
        return self._bkg_weights

    @property
    def sig_weights(self):
        """Class weight times event weight for each signal training event (for plotting).

        TODO: find a better way to do this
        """
        if self._sig_weights is None:
            logger.debug("Calculating signal weights for plotting")
            self._sig_weights = self.class_weight[1] * self.w_train[self.y_train == 1]
        return self._sig_weights

    def plot_input(self, var_index):
        """Plot one (transformed) input variable for signal and background.

        :param var_index: index into ``branches``; saved as ``plots/var_<index>.pdf``
        """
        branch = self.branches[var_index]
        fig, ax = plt.subplots()
        bkg = self.x_train[:, var_index][self.y_train == 0]
        sig = self.x_train[:, var_index][self.y_train == 1]
        logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
        logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
        ax.hist(bkg, color="b", alpha=0.5, bins=50, weights=self.bkg_weights)
        ax.hist(sig, color="r", alpha=0.5, bins=50, weights=self.sig_weights)
        ax.set_xlabel(branch+" (transformed)")
        plot_dir = os.path.join(self.project_dir, "plots")
        if not os.path.exists(plot_dir):
            os.mkdir(plot_dir)
        fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
        # release the figure - one is created per variable on every train() call
        plt.close(fig)

    def plot_ROC(self):
        """Plot the ROC curve of the test sample (weighted by w_test) to ``ROC.pdf``.

        :raises RuntimeError: if no test scores exist yet (call train first)
        """
        if self.scores_test is None:
            raise RuntimeError("No test scores available - call train() first")

        logger.info("Plot ROC curve")
        fpr, tpr, threshold = roc_curve(self.y_test, np.ravel(self.scores_test), sample_weight=self.w_test)

        # use a dedicated figure so repeated calls do not draw into the same axes
        fig, ax = plt.subplots()
        ax.grid(color='gray', linestyle='--', linewidth=1)
        ax.plot(fpr, tpr, label='NN')
        ax.plot([0, 1], [0, 1], linestyle='--', color='black', label='Luck')
        # the false positive rate is the background efficiency (1 - rejection)
        ax.set_xlabel("False positive rate (background efficiency)")
        ax.set_ylabel("True positive rate (signal efficiency)")
        ax.set_title('Receiver operating characteristic')
        ax.legend(loc='lower left', framealpha=1.0)

        fig.savefig(os.path.join(self.project_dir, "ROC.pdf"))
        plt.close(fig)

    def plot_score(self):
        """Not implemented yet."""
        raise NotImplementedError("plot_score is not implemented yet")

if __name__ == "__main__":

    # make the module's info messages visible when run as a script
    logging.basicConfig()
    logger.setLevel(logging.INFO)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = KerasROOTClassification("test",
                                signal_trees=[(filename, "GG_oneStep_1705_1105_505_NoSys")],
                                bkg_trees=[(filename, "ttbar_NoSys"),
                                           (filename, "wjets_Sherpa221_NoSys")],
                                branches=["met", "mt"],
                                weight_expr="eventWeight*genWeight",
                                identifiers=["DatasetNumber", "EventNumber"])

    c.train()
    c.plot_ROC()
