import pytest
import numpy as np
import root_numpy
import pandas as pd
from sklearn.datasets import make_classification
from keras.layers import GRU

from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN

def create_dataset(path, n_samples=10000, random_state=1):
    """Write a reproducible synthetic signal/background dataset to ROOT files.

    The dataset consists of ``n_samples`` events from
    ``sklearn.datasets.make_classification`` (weight 1) plus ``n_samples``
    pure Gaussian-noise events (weight 0.01), the first half of which are
    labelled background (0) and the rest signal (1). Every feature is then
    multiplied and shifted by random per-feature factors so that the input
    transformation of the project under test is actually exercised.

    Parameters
    ----------
    path : pathlib.Path
        Directory the ROOT files are written to (must support ``/``, e.g.
        pytest's ``tmp_path``).
    n_samples : int, optional
        Number of ``make_classification`` events; the same number of noise
        events is appended. Defaults to 10000.
    random_state : int or None, optional
        Seed for ``make_classification`` *and* for the noise, shift and scale
        draws, so repeated calls yield identical files. Defaults to 1.

    Returns
    -------
    branches : list of str
        Feature branch names ``var_0`` ... ``var_<n_features-1>``.
    tree_path_sig : str
        Path of the ROOT file holding the class-1 events (tree ``"tree"``).
    tree_path_bkg : str
        Path of the ROOT file holding the class-0 events (tree ``"tree"``).
    """
    # One seeded generator for everything. Passing it to make_classification
    # yields the same X, y as random_state=<int>, and the later draws continue
    # from it, so the noise/shift/scale are reproducible but not just a
    # replay of the stream make_classification consumed.
    rng = np.random.RandomState(random_state)

    # create example dataset with (low-weighted) noise added
    X, y = make_classification(n_samples=n_samples, random_state=rng)
    n_features = X.shape[1]
    X_noise = rng.normal(size=(n_samples, n_features))
    n_noise_bkg = n_samples // 2
    y_noise = np.concatenate(
        [np.zeros(n_noise_bkg), np.ones(n_samples - n_noise_bkg)]
    )
    X = np.concatenate([X, X_noise])
    y = np.concatenate([y, y_noise])
    w = np.concatenate([np.ones(n_samples), 0.01 * np.ones(n_samples)])

    # shift and scale randomly (to check if transformation is working)
    shift = rng.rand(n_features) * 100
    scale = rng.rand(n_features) * 1000
    X *= scale
    X += shift

    # write to root files; to_records() also stores the DataFrame index as an
    # "index" field, which the tests use as the event identifier.
    branches = ["var_{}".format(i) for i in range(n_features)]
    df = pd.DataFrame(X, columns=branches)
    df["class"] = y
    df["weight"] = w
    tree_path_bkg = str(path / "bkg.root")
    tree_path_sig = str(path / "sig.root")
    # "recreate" instead of array2root's default "update": rerunning on the same
    # directory must replace the trees, not append a second copy of the events.
    root_numpy.array2root(
        df[df["class"] == 0].to_records(), tree_path_bkg,
        treename="tree", mode="recreate",
    )
    root_numpy.array2root(
        df[df["class"] == 1].to_records(), tree_path_sig,
        treename="tree", mode="recreate",
    )
    return branches, tree_path_sig, tree_path_bkg


def test_ClassificationProject(tmp_path):
    """End-to-end check of a dense ClassificationProject on synthetic data.

    Builds the project from the generated signal/background ROOT files,
    trains it, runs the input and loss plotting, and requires the lowest
    validation loss recorded during training to be below 0.18.
    """
    branches, sig_file, bkg_file = create_dataset(tmp_path)
    project_dir = str(tmp_path / "project")

    project = ClassificationProject(
        project_dir,
        signal_trees=[(sig_file, "tree")],
        bkg_trees=[(bkg_file, "tree")],
        branches=branches,
        weight_expr="weight",
        identifiers=["index"],
        optimizer="Adam",
        earlystopping_opts={"patience": 5},
        dropout=0.5,
        layers=3,
        nodes=128,
    )

    project.train(epochs=200)
    project.plot_all_inputs()
    project.plot_loss()

    best_val_loss = min(project.history.history["val_loss"])
    assert best_val_loss < 0.18


def test_ClassificationProjectRNN(tmp_path):
    """End-to-end check of ClassificationProjectRNN with two recurrent groups.

    ``recurrent_field_names`` lists two groups of branch-name lists (the
    inner lists are presumably per-step field names — confirm against
    ClassificationProjectRNN). The built model must contain exactly two GRU
    layers (presumably one per group). The project is then trained and plotted,
    and the lowest validation loss must be below 0.18.
    """
    branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
    c = ClassificationProjectRNN(
        str(tmp_path / "project"),
        bkg_trees=[(tree_path_bkg, "tree")],
        signal_trees=[(tree_path_sig, "tree")],
        branches=branches,
        recurrent_field_names=[
            [
                ["var_1", "var_2", "var_3"],
                ["var_4", "var_5", "var_6"]
            ],
            [
                ["var_10", "var_11", "var_12"],
                ["var_13", "var_14", "var_15"]
            ],
        ],
        weight_expr="weight",
        identifiers=["index"],
        optimizer="Adam",
        earlystopping_opts=dict(patience=5),
        dropout=0.5,
        layers=3,
        nodes=128,
    )
    # Check the architecture before training so a wiring error fails fast.
    # A generator avoids building a throwaway list just to sum it.
    n_gru = sum(isinstance(layer, GRU) for layer in c.model.layers)
    assert n_gru == 2, "expected 2 GRU layers, found {}".format(n_gru)
    c.train(epochs=200)
    c.plot_all_inputs()
    c.plot_loss()
    assert min(c.history.history["val_loss"]) < 0.18