Skip to content
Snippets Groups Projects
Commit 778a981d authored by Nikolai Hartmann
Browse files

scale and shift example data for tests randomly

parent 7d32696b
No related branches found
No related tags found
No related merge requests found
...@@ -8,12 +8,22 @@ from keras.layers import GRU ...@@ -8,12 +8,22 @@ from keras.layers import GRU
from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN
def create_dataset(path): def create_dataset(path):
# create example dataset with (low-weighted) noise added
X, y = make_classification(n_samples=10000, random_state=1) X, y = make_classification(n_samples=10000, random_state=1)
X2 = np.random.normal(size=20*10000).reshape(-1, 20) X2 = np.random.normal(size=20*10000).reshape(-1, 20)
y2 = np.concatenate([np.zeros(5000), np.ones(5000)]) y2 = np.concatenate([np.zeros(5000), np.ones(5000)])
X = np.concatenate([X, X2]) X = np.concatenate([X, X2])
y = np.concatenate([y, y2]) y = np.concatenate([y, y2])
w = np.concatenate([np.ones(10000), 0.01*np.ones(10000)]) w = np.concatenate([np.ones(10000), 0.01*np.ones(10000)])
# shift and scale randomly (to check if transformation is working)
shift = np.random.rand(20)*100
scale = np.random.rand(20)*1000
X *= scale
X += shift
# write to root files
branches = ["var_{}".format(i) for i in range(len(X[0]))] branches = ["var_{}".format(i) for i in range(len(X[0]))]
df = pd.DataFrame(X, columns=branches) df = pd.DataFrame(X, columns=branches)
df["class"] = y df["class"] = y
...@@ -41,6 +51,8 @@ def test_ClassificationProject(tmp_path): ...@@ -41,6 +51,8 @@ def test_ClassificationProject(tmp_path):
nodes=128, nodes=128,
) )
c.train(epochs=200) c.train(epochs=200)
c.plot_all_inputs()
c.plot_loss()
assert min(c.history.history["val_loss"]) < 0.18 assert min(c.history.history["val_loss"]) < 0.18
...@@ -71,4 +83,6 @@ def test_ClassificationProjectRNN(tmp_path): ...@@ -71,4 +83,6 @@ def test_ClassificationProjectRNN(tmp_path):
) )
assert sum([isinstance(layer, GRU) for layer in c.model.layers]) == 2 assert sum([isinstance(layer, GRU) for layer in c.model.layers]) == 2
c.train(epochs=200) c.train(epochs=200)
c.plot_all_inputs()
c.plot_loss()
assert min(c.history.history["val_loss"]) < 0.18 assert min(c.history.history["val_loss"]) < 0.18
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment