diff --git a/toolkit.py b/toolkit.py index 789467549887f87d25c988713db4fc3429e294a5..59b08b749e7709a6d4d470954fe1fe245594f039 100755 --- a/toolkit.py +++ b/toolkit.py @@ -3,6 +3,7 @@ import os import json import pickle +import importlib import logging logger = logging.getLogger("KerasROOTClassification") @@ -20,6 +21,8 @@ from keras.models import Sequential from keras.layers import Dense from keras.models import model_from_json from keras.callbacks import History +from keras.optimizers import SGD +import keras.optimizers import matplotlib.pyplot as plt import matplotlib.pyplot as plt @@ -55,7 +58,9 @@ class KerasROOTClassification(object): out_dir="./outputs", scaler_type="RobustScaler", step_signal=2, - step_bkg=2): + step_bkg=2, + optimizer="SGD", + optimizer_opts=None): self.name = name self.signal_trees = signal_trees self.bkg_trees = bkg_trees @@ -72,6 +77,10 @@ class KerasROOTClassification(object): self.scaler_type = scaler_type self.step_signal = step_signal self.step_bkg = step_bkg + self.optimizer = optimizer + if optimizer_opts is None: + optimizer_opts = dict() + self.optimizer_opts = optimizer_opts self.project_dir = os.path.join(self.out_dir, name) @@ -279,9 +288,9 @@ class KerasROOTClassification(object): def _dump_history(self): params_file = os.path.join(self.project_dir, "history_params.json") history_file = os.path.join(self.project_dir, "history_history.json") - with open(params_file, "wb") as of: + with open(params_file, "w") as of: json.dump(self.history.params, of) - with open(history_file, "wb") as of: + with open(history_file, "w") as of: json.dump(self.history.history, of) @@ -331,11 +340,13 @@ class KerasROOTClassification(object): self._model.add(Dense(self.nodes, activation=self.activation_function)) # last layer is one neuron (binary classification) self._model.add(Dense(1, activation='sigmoid')) - + logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts)) + Optimizer = getattr(keras.optimizers, self.optimizer) + optimizer = Optimizer(**self.optimizer_opts) logger.info("Compile model") - self._model.compile(optimizer='SGD', - loss='binary_crossentropy', - metrics=['accuracy']) + self._model.compile(optimizer=optimizer, + loss='binary_crossentropy', + metrics=['accuracy']) # dump to json for documentation with open(os.path.join(self.project_dir, "model.json"), "w") as of: @@ -518,7 +529,7 @@ class KerasROOTClassification(object): def plot_score(self): pass - + def plot_loss(self): logger.info("Plot losses") @@ -529,10 +540,10 @@ class KerasROOTClassification(object): plt.legend(['train','test'], loc='upper left') plt.savefig(os.path.join(self.project_dir, "losses.pdf")) plt.clf() - + def plot_accuracy(self): - + logger.info("Plot accuracy") plt.plot(self.history.history['acc']) plt.plot(self.history.history['val_acc'])