diff --git a/toolkit.py b/toolkit.py index 9605aa4b887954c75c9ef6fa95ef178f663c014e..82391fa1bb60131cb70682d2a7453616de37ff6e 100755 --- a/toolkit.py +++ b/toolkit.py @@ -75,6 +75,11 @@ class ClassificationProject(object): :param weight_expr: expression to weight the events in the loss function + :param data_dir: if given, load the data from the given directory of a previous project + instead of creating it from trees. If the data is on the same + disk (and the filesystem supports it), hard links will be used, + otherwise symlinks. + :param identifiers: list of branches or expressions that uniquely identify events. This is used to store the list of training events, such that they can be marked later on, for example when @@ -152,6 +157,7 @@ class ClassificationProject(object): def _init_from_args(self, name, signal_trees, bkg_trees, branches, weight_expr, + data_dir=None, identifiers=None, selection=None, layers=3, @@ -179,6 +185,7 @@ class ClassificationProject(object): self.branches = branches self.weight_expr = weight_expr self.selection = selection + self.data_dir = data_dir if identifiers is None: identifiers = [] self.identifiers = identifiers @@ -363,6 +370,14 @@ class ClassificationProject(object): dataset_names = self.dataset_names for dataset_name in dataset_names: filename = os.path.join(self.project_dir, dataset_name+".h5") + if (self.data_dir is not None) and (not os.path.exists(filename)): + srcpath = os.path.join(self.data_dir, dataset_name+".h5") + try: + os.link(srcpath, filename) + logger.info("Created hardlink from {} to {}".format(srcpath, filename)) + except OSError: + os.symlink(srcpath, filename) + logger.info("Created symlink from {} to {}".format(srcpath, filename)) logger.info("Trying to load {} from {}".format(dataset_name, filename)) with h5py.File(filename) as hf: setattr(self, dataset_name, hf[dataset_name][:])