#!/usr/bin/env python

import os,sys,inspect
# Put the parent of this script's directory at the front of sys.path, so the
# ``toolkit`` import that follows can be resolved from there (presumably the
# package lives one level up from this script -- confirm against the repo).
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)
import toolkit
from toolkit import KerasROOTClassification

def init_model(geneparam,
               filename="/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"):
    """Build a ``KerasROOTClassification`` configured from a genome.

    :param geneparam: mapping that must provide the keys ``'nb_layers'``,
        ``'nb_neurons'``, ``'activation'`` and ``'optimizer'``. Their values
        are forwarded to the classifier as ``layers``, ``nodes``,
        ``activation_function`` and ``optimizer`` respectively.
    :param filename: ROOT file that holds both the signal tree and all the
        background trees. Defaults to the path that was previously
        hard-coded here, so existing callers behave exactly as before.
    :returns: the constructed ``KerasROOTClassification`` instance. This
        function does not train it; ``train_and_score`` does that.
    :raises KeyError: if one of the required keys is missing from
        ``geneparam``.
    """
    nb_layers = geneparam['nb_layers']
    nb_neurons = geneparam['nb_neurons']
    activation = geneparam['activation']
    optimizer = geneparam['optimizer']
    # NOTE: optimizer options (lr / decay / momentum) are not evolved yet;
    # they could be forwarded via ``optimizer_opts`` if added to geneparam.

    # Background processes, all read from the same ROOT file as the signal.
    bkg_tree_names = (
        "ttbar_NoSys",
        "wjets_Sherpa221_NoSys",
        "zjets_Sherpa221_NoSys",
        "diboson_Sherpa221_NoSys",
        "ttv_NoSys",
        "singletop_NoSys",
    )

    # NOTE(review): the first positional argument is an empty string; its
    # meaning (name / output location?) is defined in toolkit -- confirm.
    c = KerasROOTClassification("",
                                signal_trees=[(filename, "GG_oneStep_1545_1265_985_NoSys")],
                                bkg_trees=[(filename, tree_name) for tree_name in bkg_tree_names],
                                dumping_enabled=False,
                                optimizer=optimizer,
                                layers=nb_layers,
                                nodes=nb_neurons,
                                activation_function=activation,
                                earlystopping_opts=dict(monitor='val_loss',
                                    min_delta=0, patience=2, verbose=0, mode='auto'),
                                selection="lep1Pt<5000",  # cut out a few very weird outliers
                                branches=["met", "mt"],
                                weight_expr="eventWeight*genWeight",
                                identifiers=["DatasetNumber", "EventNumber"],
                                # presumably a background subsampling step -- see toolkit
                                step_bkg=100)
    return c

def train_and_score(geneparam, epochs=20):
    """Build a model for ``geneparam``, train it, and return its score.

    :param geneparam: genome parameters, forwarded unchanged to
        ``init_model``.
    :param epochs: number of training epochs passed to ``model.train``.
        Defaults to 20, the value previously hard-coded here.
    :returns: element 1 of ``model.score`` after training. According to the
        original note, index 1 is the accuracy and index 0 the loss.
    """
    model = init_model(geneparam)

    model.train(epochs=epochs)

    # model.score is indexed as (loss, accuracy): 0 is loss, 1 is accuracy.
    return model.score[1]