From 778a981dea833bd76b02008abd91278b505ad7c5 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <nikolai.hartmann@gmx.de>
Date: Mon, 22 Oct 2018 10:52:54 +0200
Subject: [PATCH] scale and shift example data for tests randomly

---
 test/test_toolkit.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/test/test_toolkit.py b/test/test_toolkit.py
index 15d148b..ecd3aa2 100644
--- a/test/test_toolkit.py
+++ b/test/test_toolkit.py
@@ -8,12 +8,22 @@ from keras.layers import GRU
 from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN
 
 def create_dataset(path):
+
+    # create example dataset with (low-weighted) noise added
     X, y = make_classification(n_samples=10000, random_state=1)
     X2 = np.random.normal(size=20*10000).reshape(-1, 20)
     y2 = np.concatenate([np.zeros(5000), np.ones(5000)])
     X = np.concatenate([X, X2])
     y = np.concatenate([y, y2])
     w = np.concatenate([np.ones(10000), 0.01*np.ones(10000)])
+
+    # shift and scale randomly (to check if transformation is working)
+    shift = np.random.rand(20)*100
+    scale = np.random.rand(20)*1000
+    X *= scale
+    X += shift
+
+    # write to root files
     branches = ["var_{}".format(i) for i in range(len(X[0]))]
     df = pd.DataFrame(X, columns=branches)
     df["class"] = y
@@ -41,6 +51,8 @@ def test_ClassificationProject(tmp_path):
         nodes=128,
     )
     c.train(epochs=200)
+    c.plot_all_inputs()
+    c.plot_loss()
     assert min(c.history.history["val_loss"]) < 0.18
 
 
@@ -71,4 +83,6 @@ def test_ClassificationProjectRNN(tmp_path):
     )
     assert sum([isinstance(layer, GRU) for layer in c.model.layers]) == 2
     c.train(epochs=200)
+    c.plot_all_inputs()
+    c.plot_loss()
     assert min(c.history.history["val_loss"]) < 0.18
-- 
GitLab