From 10a63753017f8fd2fc3382633f0aa9d0cebd748c Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Fri, 2 Nov 2018 14:42:58 +0100 Subject: [PATCH] add mechanism to set number of threads via environment variablel --- test/test_toolkit.py | 2 ++ toolkit.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/test/test_toolkit.py b/test/test_toolkit.py index ecd3aa2..1ebd707 100644 --- a/test/test_toolkit.py +++ b/test/test_toolkit.py @@ -7,6 +7,7 @@ from keras.layers import GRU from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN + def create_dataset(path): # create example dataset with (low-weighted) noise added @@ -50,6 +51,7 @@ def test_ClassificationProject(tmp_path): layers=3, nodes=128, ) + c.train(epochs=200) c.plot_all_inputs() c.plot_loss() diff --git a/toolkit.py b/toolkit.py index eb92344..ec074bd 100755 --- a/toolkit.py +++ b/toolkit.py @@ -42,6 +42,7 @@ import keras.initializers import keras.optimizers from keras.utils.vis_utils import model_to_dot from keras import backend as K +import tensorflow as tf import matplotlib.pyplot as plt from .utils import WeightedRobustScaler, weighted_quantile, poisson_asimov_significance @@ -65,6 +66,25 @@ if version_info[0] > 2: byteify = lambda input : input +def set_session_threads(n_cpu=None): + "Set the number of threads based on OMP_NUM_THREADS or the given argument" + + if n_cpu is None: + if os.environ.get('OMP_NUM_THREADS'): + n_cpu = int(os.environ.get('OMP_NUM_THREADS')) + else: + return + + # not sure if this is the best configuration ... + config = tf.ConfigProto(intra_op_parallelism_threads=n_cpu, + inter_op_parallelism_threads=1, + allow_soft_placement=True, + #log_device_placement=True, + device_count = {'CPU': n_cpu}) + session = tf.Session(config=config) + K.set_session(session) + + def load_from_dir(path): "Load a project and the options from a directory" try: @@ -939,6 +959,8 @@ class ClassificationProject(object): self.total_epochs = self._read_info("epochs", 0) + set_session_threads() + logger.info("Train model") if not self.balance_dataset: try: -- GitLab