From 58e11537df8f476953d1406a151a804d43771996 Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Mon, 11 Jun 2018 11:46:45 +0200 Subject: [PATCH] Adding option data_dir to automatically reuse (via hard or symlink) previously created arrays --- toolkit.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/toolkit.py b/toolkit.py index 9605aa4..82391fa 100755 --- a/toolkit.py +++ b/toolkit.py @@ -75,6 +75,11 @@ class ClassificationProject(object): :param weight_expr: expression to weight the events in the loss function + :param data_dir: if given, load the data from a previous project with the given name + instead of creating it from trees. If the data is on the same + disk (and the filesystem supports it), hard links will be used, + otherwise symlinks. + :param identifiers: list of branches or expressions that uniquely identify events. This is used to store the list of training events, such that they can be marked later on, for example when @@ -152,6 +157,7 @@ class ClassificationProject(object): def _init_from_args(self, name, signal_trees, bkg_trees, branches, weight_expr, + data_dir=None, identifiers=None, selection=None, layers=3, @@ -179,6 +185,7 @@ class ClassificationProject(object): self.branches = branches self.weight_expr = weight_expr self.selection = selection + self.data_dir = data_dir if identifiers is None: identifiers = [] self.identifiers = identifiers @@ -363,6 +370,14 @@ class ClassificationProject(object): dataset_names = self.dataset_names for dataset_name in dataset_names: filename = os.path.join(self.project_dir, dataset_name+".h5") + if (self.data_dir is not None) and (not os.path.exists(filename)): + srcpath = os.path.join(self.data_dir, dataset_name+".h5") + try: + os.link(srcpath, filename) + logger.info("Created hardlink from {} to {}".format(srcpath, filename)) + except OSError: + os.symlink(srcpath, filename) + logger.info("Created symlink from {} to {}".format(srcpath, filename)) logger.info("Trying to load {} from {}".format(dataset_name, filename)) with h5py.File(filename) as hf: setattr(self, dataset_name, hf[dataset_name][:]) -- GitLab