diff --git a/toolkit.py b/toolkit.py index 59b08b749e7709a6d4d470954fe1fe245594f039..eb5b02c4a0c94cac86d944257baab8f75f977b0a 100755 --- a/toolkit.py +++ b/toolkit.py @@ -47,20 +47,26 @@ class KerasROOTClassification(object): dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"] - def __init__(self, name, - signal_trees, bkg_trees, branches, weight_expr, identifiers, - selection=None, - layers=3, - nodes=64, - batch_size=128, - validation_split=0.33, - activation_function='relu', - out_dir="./outputs", - scaler_type="RobustScaler", - step_signal=2, - step_bkg=2, - optimizer="SGD", - optimizer_opts=None): + def __init__(self, name, *args, **kwargs): + self._init_from_args(name, *args, **kwargs) + with open(os.path.join(self.project_dir, "options.json"), "w") as of: + json.dump(dict(args=args, kwargs=kwargs), of) + + + def _init_from_args(self, name, + signal_trees, bkg_trees, branches, weight_expr, identifiers, + selection=None, + layers=3, + nodes=64, + batch_size=128, + validation_split=0.33, + activation_function='relu', + out_dir="./outputs", + scaler_type="RobustScaler", + step_signal=2, + step_bkg=2, + optimizer="SGD", + optimizer_opts=None): self.name = name self.signal_trees = signal_trees self.bkg_trees = bkg_trees