Skip to content
Snippets Groups Projects
Commit 1cce79f1 authored by Nikolai's avatar Nikolai
Browse files

Make optimiser and its arguments configurable

parent 11ff3e13
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
import os
import json
import pickle
import importlib
import logging
logger = logging.getLogger("KerasROOTClassification")
......@@ -20,6 +21,8 @@ from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
from keras.callbacks import History
from keras.optimizers import SGD
import keras.optimizers
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
......@@ -55,7 +58,9 @@ class KerasROOTClassification(object):
out_dir="./outputs",
scaler_type="RobustScaler",
step_signal=2,
step_bkg=2):
step_bkg=2,
optimizer="SGD",
optimizer_opts=None):
self.name = name
self.signal_trees = signal_trees
self.bkg_trees = bkg_trees
......@@ -72,6 +77,10 @@ class KerasROOTClassification(object):
self.scaler_type = scaler_type
self.step_signal = step_signal
self.step_bkg = step_bkg
self.optimizer = optimizer
if optimizer_opts is None:
optimizer_opts = dict()
self.optimizer_opts = optimizer_opts
self.project_dir = os.path.join(self.out_dir, name)
......@@ -279,9 +288,9 @@ class KerasROOTClassification(object):
def _dump_history(self):
params_file = os.path.join(self.project_dir, "history_params.json")
history_file = os.path.join(self.project_dir, "history_history.json")
with open(params_file, "wb") as of:
with open(params_file, "w") as of:
json.dump(self.history.params, of)
with open(history_file, "wb") as of:
with open(history_file, "w") as of:
json.dump(self.history.history, of)
......@@ -331,11 +340,13 @@ class KerasROOTClassification(object):
self._model.add(Dense(self.nodes, activation=self.activation_function))
# last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid'))
logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts)
logger.info("Compile model")
self._model.compile(optimizer='SGD',
loss='binary_crossentropy',
metrics=['accuracy'])
self._model.compile(optimizer=optimizer,
loss='binary_crossentropy',
metrics=['accuracy'])
# dump to json for documentation
with open(os.path.join(self.project_dir, "model.json"), "w") as of:
......@@ -518,7 +529,7 @@ class KerasROOTClassification(object):
def plot_score(self):
pass
def plot_loss(self):
logger.info("Plot losses")
......@@ -529,10 +540,10 @@ class KerasROOTClassification(object):
plt.legend(['train','test'], loc='upper left')
plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
plt.clf()
def plot_accuracy(self):
logger.info("Plot accuracy")
plt.plot(self.history.history['acc'])
plt.plot(self.history.history['val_acc'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment