From 778a981dea833bd76b02008abd91278b505ad7c5 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <nikolai.hartmann@gmx.de> Date: Mon, 22 Oct 2018 10:52:54 +0200 Subject: [PATCH] scale and shift example data for tests randomly --- test/test_toolkit.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/test_toolkit.py b/test/test_toolkit.py index 15d148b..ecd3aa2 100644 --- a/test/test_toolkit.py +++ b/test/test_toolkit.py @@ -8,12 +8,22 @@ from keras.layers import GRU from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN def create_dataset(path): + + # create example dataset with (low-weighted) noise added X, y = make_classification(n_samples=10000, random_state=1) X2 = np.random.normal(size=20*10000).reshape(-1, 20) y2 = np.concatenate([np.zeros(5000), np.ones(5000)]) X = np.concatenate([X, X2]) y = np.concatenate([y, y2]) w = np.concatenate([np.ones(10000), 0.01*np.ones(10000)]) + + # shift and scale randomly (to check if transformation is working) + shift = np.random.rand(20)*100 + scale = np.random.rand(20)*1000 + X *= scale + X += shift + + # write to root files branches = ["var_{}".format(i) for i in range(len(X[0]))] df = pd.DataFrame(X, columns=branches) df["class"] = y @@ -41,6 +51,8 @@ def test_ClassificationProject(tmp_path): nodes=128, ) c.train(epochs=200) + c.plot_all_inputs() + c.plot_loss() assert min(c.history.history["val_loss"]) < 0.18 @@ -71,4 +83,6 @@ def test_ClassificationProjectRNN(tmp_path): ) assert sum([isinstance(layer, GRU) for layer in c.model.layers]) == 2 c.train(epochs=200) + c.plot_all_inputs() + c.plot_loss() assert min(c.history.history["val_loss"]) < 0.18 -- GitLab