Skip to content
Snippets Groups Projects
Commit 1cce79f1 authored by Nikolai's avatar Nikolai
Browse files

Make optimiser and its arguments configurable

parent 11ff3e13
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import json import json
import pickle import pickle
import importlib
import logging import logging
logger = logging.getLogger("KerasROOTClassification") logger = logging.getLogger("KerasROOTClassification")
...@@ -20,6 +21,8 @@ from keras.models import Sequential ...@@ -20,6 +21,8 @@ from keras.models import Sequential
from keras.layers import Dense from keras.layers import Dense
from keras.models import model_from_json from keras.models import model_from_json
from keras.callbacks import History from keras.callbacks import History
from keras.optimizers import SGD
import keras.optimizers
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -55,7 +58,9 @@ class KerasROOTClassification(object): ...@@ -55,7 +58,9 @@ class KerasROOTClassification(object):
out_dir="./outputs", out_dir="./outputs",
scaler_type="RobustScaler", scaler_type="RobustScaler",
step_signal=2, step_signal=2,
step_bkg=2): step_bkg=2,
optimizer="SGD",
optimizer_opts=None):
self.name = name self.name = name
self.signal_trees = signal_trees self.signal_trees = signal_trees
self.bkg_trees = bkg_trees self.bkg_trees = bkg_trees
...@@ -72,6 +77,10 @@ class KerasROOTClassification(object): ...@@ -72,6 +77,10 @@ class KerasROOTClassification(object):
self.scaler_type = scaler_type self.scaler_type = scaler_type
self.step_signal = step_signal self.step_signal = step_signal
self.step_bkg = step_bkg self.step_bkg = step_bkg
self.optimizer = optimizer
if optimizer_opts is None:
optimizer_opts = dict()
self.optimizer_opts = optimizer_opts
self.project_dir = os.path.join(self.out_dir, name) self.project_dir = os.path.join(self.out_dir, name)
...@@ -279,9 +288,9 @@ class KerasROOTClassification(object): ...@@ -279,9 +288,9 @@ class KerasROOTClassification(object):
def _dump_history(self): def _dump_history(self):
params_file = os.path.join(self.project_dir, "history_params.json") params_file = os.path.join(self.project_dir, "history_params.json")
history_file = os.path.join(self.project_dir, "history_history.json") history_file = os.path.join(self.project_dir, "history_history.json")
with open(params_file, "wb") as of: with open(params_file, "w") as of:
json.dump(self.history.params, of) json.dump(self.history.params, of)
with open(history_file, "wb") as of: with open(history_file, "w") as of:
json.dump(self.history.history, of) json.dump(self.history.history, of)
...@@ -331,11 +340,13 @@ class KerasROOTClassification(object): ...@@ -331,11 +340,13 @@ class KerasROOTClassification(object):
self._model.add(Dense(self.nodes, activation=self.activation_function)) self._model.add(Dense(self.nodes, activation=self.activation_function))
# last layer is one neuron (binary classification) # last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid')) self._model.add(Dense(1, activation='sigmoid'))
logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts)
logger.info("Compile model") logger.info("Compile model")
self._model.compile(optimizer='SGD', self._model.compile(optimizer=optimizer,
loss='binary_crossentropy', loss='binary_crossentropy',
metrics=['accuracy']) metrics=['accuracy'])
# dump to json for documentation # dump to json for documentation
with open(os.path.join(self.project_dir, "model.json"), "w") as of: with open(os.path.join(self.project_dir, "model.json"), "w") as of:
...@@ -518,7 +529,7 @@ class KerasROOTClassification(object): ...@@ -518,7 +529,7 @@ class KerasROOTClassification(object):
def plot_score(self): def plot_score(self):
pass pass
def plot_loss(self): def plot_loss(self):
logger.info("Plot losses") logger.info("Plot losses")
...@@ -529,10 +540,10 @@ class KerasROOTClassification(object): ...@@ -529,10 +540,10 @@ class KerasROOTClassification(object):
plt.legend(['train','test'], loc='upper left') plt.legend(['train','test'], loc='upper left')
plt.savefig(os.path.join(self.project_dir, "losses.pdf")) plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
plt.clf() plt.clf()
def plot_accuracy(self): def plot_accuracy(self):
logger.info("Plot accuracy") logger.info("Plot accuracy")
plt.plot(self.history.history['acc']) plt.plot(self.history.history['acc'])
plt.plot(self.history.history['val_acc']) plt.plot(self.history.history['val_acc'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment