diff --git a/test/test_toolkit.py b/test/test_toolkit.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d148be02964466ac4cba0b4ca45552dcc176d3
--- /dev/null
+++ b/test/test_toolkit.py
@@ -0,0 +1,74 @@
+import pytest
+import numpy as np
+import root_numpy
+import pandas as pd
+from sklearn.datasets import make_classification
+from keras.layers import GRU
+
+from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN
+
+def create_dataset(path):
+    X, y = make_classification(n_samples=10000, random_state=1)
+    X2 = np.random.normal(size=20*10000).reshape(-1, 20)
+    y2 = np.concatenate([np.zeros(5000), np.ones(5000)])
+    X = np.concatenate([X, X2])
+    y = np.concatenate([y, y2])
+    w = np.concatenate([np.ones(10000), 0.01*np.ones(10000)])
+    branches = ["var_{}".format(i) for i in range(len(X[0]))]
+    df = pd.DataFrame(X, columns=branches)
+    df["class"] = y
+    df["weight"] = w
+    tree_path_bkg = str(path / "bkg.root")
+    tree_path_sig = str(path / "sig.root")
+    root_numpy.array2root(df[df["class"]==0].to_records(), tree_path_bkg)
+    root_numpy.array2root(df[df["class"]==1].to_records(), tree_path_sig)
+    return branches, tree_path_sig, tree_path_bkg
+
+
+def test_ClassificationProject(tmp_path):
+    branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
+    c = ClassificationProject(
+        str(tmp_path / "project"),
+        bkg_trees = [(tree_path_bkg, "tree")],
+        signal_trees = [(tree_path_sig, "tree")],
+        branches = branches,
+        weight_expr = "weight",
+        identifiers = ["index"],
+        optimizer="Adam",
+        earlystopping_opts=dict(patience=5),
+        dropout=0.5,
+        layers=3,
+        nodes=128,
+    )
+    c.train(epochs=200)
+    assert min(c.history.history["val_loss"]) < 0.18
+
+
+def test_ClassificationProjectRNN(tmp_path):
+    branches, tree_path_sig, tree_path_bkg = create_dataset(tmp_path)
+    c = ClassificationProjectRNN(
+        str(tmp_path / "project"),
+        bkg_trees = [(tree_path_bkg, "tree")],
+        signal_trees = [(tree_path_sig, "tree")],
+        branches = branches,
+        recurrent_field_names=[
+            [
+                ["var_1", "var_2", "var_3"],
+                ["var_4", "var_5", "var_6"]
+            ],
+            [
+                ["var_10", "var_11", "var_12"],
+                ["var_13", "var_14", "var_15"]
+            ],
+        ],
+        weight_expr = "weight",
+        identifiers = ["index"],
+        optimizer="Adam",
+        earlystopping_opts=dict(patience=5),
+        dropout=0.5,
+        layers=3,
+        nodes=128,
+    )
+    assert sum([isinstance(layer, GRU) for layer in c.model.layers]) == 2
+    c.train(epochs=200)
+    assert min(c.history.history["val_loss"]) < 0.18