From 58e11537df8f476953d1406a151a804d43771996 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Mon, 11 Jun 2018 11:46:45 +0200
Subject: [PATCH] Adding option data_dir to automatically reuse (via hard link or
 symlink) previously created arrays

---
 toolkit.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/toolkit.py b/toolkit.py
index 9605aa4..82391fa 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -75,6 +75,11 @@ class ClassificationProject(object):
 
     :param weight_expr: expression to weight the events in the loss function
 
+    :param data_dir: if given, reuse the data of a previous project stored in this
+                     directory instead of creating it from trees. If the data is on the
+                     same disk (and the filesystem supports it), hard links will be used,
+                     otherwise symlinks.
+
     :param identifiers: list of branches or expressions that uniquely
                         identify events. This is used to store the list of training
                         events, such that they can be marked later on, for example when
@@ -152,6 +157,7 @@ class ClassificationProject(object):
 
     def _init_from_args(self, name,
                         signal_trees, bkg_trees, branches, weight_expr,
+                        data_dir=None,
                         identifiers=None,
                         selection=None,
                         layers=3,
@@ -179,6 +185,7 @@ class ClassificationProject(object):
         self.branches = branches
         self.weight_expr = weight_expr
         self.selection = selection
+        self.data_dir = data_dir
         if identifiers is None:
             identifiers = []
         self.identifiers = identifiers
@@ -363,6 +370,14 @@ class ClassificationProject(object):
             dataset_names = self.dataset_names
         for dataset_name in dataset_names:
             filename = os.path.join(self.project_dir, dataset_name+".h5")
+            if (self.data_dir is not None) and (not os.path.exists(filename)):
+                srcpath = os.path.join(self.data_dir, dataset_name+".h5")
+                try:
+                    os.link(srcpath, filename)
+                    logger.info("Created hardlink from {} to {}".format(srcpath, filename))
+                except OSError:
+                    os.symlink(srcpath, filename)
+                    logger.info("Created symlink from {} to {}".format(srcpath, filename))
             logger.info("Trying to load {} from {}".format(dataset_name, filename))
             with h5py.File(filename) as hf:
                 setattr(self, dataset_name,  hf[dataset_name][:])
-- 
GitLab