"""Smoke tests for the KerasROOTClassification toolkit.

A synthetic two-class dataset is written to a pair of ROOT files and then
used to train a ``ClassificationProject`` and a ``ClassificationProjectRNN``
end to end.
"""
import pytest
import numpy as np
import root_numpy
import pandas as pd
from sklearn.datasets import make_classification
from keras.layers import GRU

from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN


def create_dataset(path):
    """Write a synthetic signal/background dataset as ROOT files below *path*.

    The dataset is the concatenation of two equally sized parts:

    * a separable sample from ``sklearn.datasets.make_classification``
      carrying unit weight, and
    * pure Gaussian noise with labels that are independent of the features,
      carrying weight 0.01.

    Both parts are generated from fixed seeds, so the files written are the
    same on every run.

    Parameters
    ----------
    path : pathlib.Path
        Existing directory in which ``bkg.root`` and ``sig.root`` are created
        (the tests pass pytest's ``tmp_path``).

    Returns
    -------
    tuple
        ``(branches, tree_path_sig, tree_path_bkg)``: the list of feature
        branch names (``var_0`` ... ``var_{n-1}``), followed by the paths of
        the signal (class 1) and background (class 0) files as strings.
    """
    n_samples = 10000
    n_noise = 10000

    X, y = make_classification(n_samples=n_samples, random_state=1)
    n_features = X.shape[1]

    # Use a dedicated, seeded generator for the noise part. Drawing from the
    # global numpy state would make the fixture differ between runs while the
    # tests compare the validation loss against a fixed threshold.
    rng = np.random.RandomState(1)
    # The noise width follows the real sample instead of a literal column
    # count, so the concatenation below cannot go out of sync with it.
    X2 = rng.normal(size=(n_noise, n_features))
    y2 = np.concatenate([np.zeros(n_noise // 2), np.ones(n_noise - n_noise // 2)])

    X = np.concatenate([X, X2])
    y = np.concatenate([y, y2])
    # Down-weight the uninformative noise events by a factor of 100.
    w = np.concatenate([np.ones(n_samples), 0.01 * np.ones(n_noise)])

    branches = ["var_{}".format(i) for i in range(n_features)]
    df = pd.DataFrame(X, columns=branches)
    df["class"] = y
    df["weight"] = w

    tree_path_bkg = str(path / "bkg.root")
    tree_path_sig = str(path / "sig.root")
    # to_records() keeps the DataFrame index as an "index" field; the tests
    # pass that name as the event identifier. No tree name is given here, and
    # the tests read the tree back under the name "tree".
    root_numpy.array2root(df[df["class"] == 0].to_records(), tree_path_bkg)
    root_numpy.array2root(df[df["class"] == 1].to_records(), tree_path_sig)
    return branches, tree_path_sig, tree_path_bkg


def test_ClassificationProject(tmp_path):
    """Train the plain classifier and require a minimum validation loss below 0.18."""
    branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
    c = ClassificationProject(
        str(tmp_path / "project"),
        bkg_trees=[(tree_path_bkg, "tree")],
        signal_trees=[(tree_path_sig, "tree")],
        branches=branches,
        weight_expr="weight",
        identifiers=["index"],
        optimizer="Adam",
        earlystopping_opts=dict(patience=5),
        dropout=0.5,
        layers=3,
        nodes=128,
    )
    c.train(epochs=200)
    assert min(c.history.history["val_loss"]) < 0.18


def test_ClassificationProjectRNN(tmp_path):
    """Train the recurrent classifier and check its layer structure and loss.

    Two groups of recurrent fields are configured and the model is expected
    to contain exactly two GRU layers (presumably one per group -- confirm
    against ClassificationProjectRNN). The loss requirement is the same as
    for the plain classifier.
    """
    branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
    c = ClassificationProjectRNN(
        str(tmp_path / "project"),
        bkg_trees=[(tree_path_bkg, "tree")],
        signal_trees=[(tree_path_sig, "tree")],
        branches=branches,
        recurrent_field_names=[
            [
                ["var_1", "var_2", "var_3"],
                ["var_4", "var_5", "var_6"],
            ],
            [
                ["var_10", "var_11", "var_12"],
                ["var_13", "var_14", "var_15"],
            ],
        ],
        weight_expr="weight",
        identifiers=["index"],
        optimizer="Adam",
        earlystopping_opts=dict(patience=5),
        dropout=0.5,
        layers=3,
        nodes=128,
    )
    # The architecture is checked before training.
    assert sum(isinstance(layer, GRU) for layer in c.model.layers) == 2
    c.train(epochs=200)
    assert min(c.history.history["val_loss"]) < 0.18