Skip to content
Snippets Groups Projects
Commit 58e11537 authored by Nikolai's avatar Nikolai
Browse files

Adding option data_dir to automatically reuse (via hard or symlink) previously created arrays

parent 2434384a
No related branches found
No related tags found
No related merge requests found
...@@ -75,6 +75,11 @@ class ClassificationProject(object): ...@@ -75,6 +75,11 @@ class ClassificationProject(object):
:param weight_expr: expression to weight the events in the loss function :param weight_expr: expression to weight the events in the loss function
:param data_dir: if given, load the data from a previous project with the given name
instead of creating it from trees. If the data is on the same
disk (and the filesystem supports it), hard links will be used,
otherwise symlinks.
:param identifiers: list of branches or expressions that uniquely :param identifiers: list of branches or expressions that uniquely
identify events. This is used to store the list of training identify events. This is used to store the list of training
events, such that they can be marked later on, for example when events, such that they can be marked later on, for example when
...@@ -152,6 +157,7 @@ class ClassificationProject(object): ...@@ -152,6 +157,7 @@ class ClassificationProject(object):
def _init_from_args(self, name, def _init_from_args(self, name,
signal_trees, bkg_trees, branches, weight_expr, signal_trees, bkg_trees, branches, weight_expr,
data_dir=None,
identifiers=None, identifiers=None,
selection=None, selection=None,
layers=3, layers=3,
...@@ -179,6 +185,7 @@ class ClassificationProject(object): ...@@ -179,6 +185,7 @@ class ClassificationProject(object):
self.branches = branches self.branches = branches
self.weight_expr = weight_expr self.weight_expr = weight_expr
self.selection = selection self.selection = selection
self.data_dir = data_dir
if identifiers is None: if identifiers is None:
identifiers = [] identifiers = []
self.identifiers = identifiers self.identifiers = identifiers
...@@ -363,6 +370,14 @@ class ClassificationProject(object): ...@@ -363,6 +370,14 @@ class ClassificationProject(object):
dataset_names = self.dataset_names dataset_names = self.dataset_names
for dataset_name in dataset_names: for dataset_name in dataset_names:
filename = os.path.join(self.project_dir, dataset_name+".h5") filename = os.path.join(self.project_dir, dataset_name+".h5")
if (self.data_dir is not None) and (not os.path.exists(filename)):
srcpath = os.path.join(self.data_dir, dataset_name+".h5")
try:
os.link(srcpath, filename)
logger.info("Created hardlink from {} to {}".format(srcpath, filename))
except OSError:
os.symlink(srcpath, filename)
logger.info("Created symlink from {} to {}".format(srcpath, filename))
logger.info("Trying to load {} from {}".format(dataset_name, filename)) logger.info("Trying to load {} from {}".format(dataset_name, filename))
with h5py.File(filename) as hf: with h5py.File(filename) as hf:
setattr(self, dataset_name, hf[dataset_name][:]) setattr(self, dataset_name, hf[dataset_name][:])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment