diff --git a/toolkit.py b/toolkit.py index 8424fd5480ef4db6c826dfb138715931cef4503e..99a23a5b76615a548f6d2d4e6a079582504b3d6c 100755 --- a/toolkit.py +++ b/toolkit.py @@ -50,9 +50,20 @@ class KerasROOTClassification(object): dataset_names_tree = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"] def __init__(self, name, *args, **kwargs): - self._init_from_args(name, *args, **kwargs) - with open(os.path.join(self.project_dir, "options.json"), "w") as of: - json.dump(dict(args=args, kwargs=kwargs), of) + if len(args) < 1 and len(kwargs) < 1: + # if no further arguments given, interpret as directory name + self._init_from_dir(name) + else: + # otherwise initialise new project + self._init_from_args(name, *args, **kwargs) + with open(os.path.join(self.project_dir, "options.json"), "w") as of: + json.dump(dict(args=args, kwargs=kwargs), of) + + + def _init_from_dir(self, dirname): + with open(os.path.join(dirname, "options.json")) as f: + options = json.load(f) + self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"]) def _init_from_args(self, name,