Skip to content
Snippets Groups Projects
Commit 649d4d96 authored by Nikolai Hartmann's avatar Nikolai Hartmann
Browse files

adding tests

parent e7a75351
No related branches found
No related tags found
No related merge requests found
import pytest
import numpy as np
import root_numpy
import pandas as pd
from sklearn.datasets import make_classification
from keras.layers import GRU
from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN
def create_dataset(path):
X, y = make_classification(n_samples=10000, random_state=1)
X2 = np.random.normal(size=20*10000).reshape(-1, 20)
y2 = np.concatenate([np.zeros(5000), np.ones(5000)])
X = np.concatenate([X, X2])
y = np.concatenate([y, y2])
w = np.concatenate([np.ones(10000), 0.01*np.ones(10000)])
branches = ["var_{}".format(i) for i in range(len(X[0]))]
df = pd.DataFrame(X, columns=branches)
df["class"] = y
df["weight"] = w
tree_path_bkg = str(path / "bkg.root")
tree_path_sig = str(path / "sig.root")
root_numpy.array2root(df[df["class"]==0].to_records(), tree_path_bkg)
root_numpy.array2root(df[df["class"]==1].to_records(), tree_path_sig)
return branches, tree_path_sig, tree_path_bkg
def test_ClassificationProject(tmp_path):
branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
c = ClassificationProject(
str(tmp_path / "project"),
bkg_trees = [(tree_path_bkg, "tree")],
signal_trees = [(tree_path_sig, "tree")],
branches = branches,
weight_expr = "weight",
identifiers = ["index"],
optimizer="Adam",
earlystopping_opts=dict(patience=5),
dropout=0.5,
layers=3,
nodes=128,
)
c.train(epochs=200)
assert min(c.history.history["val_loss"]) < 0.18
def test_ClassificationProjectRNN(tmp_path):
branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
c = ClassificationProjectRNN(
str(tmp_path / "project"),
bkg_trees = [(tree_path_bkg, "tree")],
signal_trees = [(tree_path_sig, "tree")],
branches = branches,
recurrent_field_names=[
[
["var_1", "var_2", "var_3"],
["var_4", "var_5", "var_6"]
],
[
["var_10", "var_11", "var_12"],
["var_13", "var_14", "var_15"]
],
],
weight_expr = "weight",
identifiers = ["index"],
optimizer="Adam",
earlystopping_opts=dict(patience=5),
dropout=0.5,
layers=3,
nodes=128,
)
assert sum([isinstance(layer, GRU) for layer in c.model.layers]) == 2
c.train(epochs=200)
assert min(c.history.history["val_loss"]) < 0.18
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment