Skip to content
Snippets Groups Projects
Commit 10a63753 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

add mechanism to set number of threads via environment variablel

parent ac35ad64
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ from keras.layers import GRU
from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN
def create_dataset(path):
# create example dataset with (low-weighted) noise added
......@@ -50,6 +51,7 @@ def test_ClassificationProject(tmp_path):
layers=3,
nodes=128,
)
c.train(epochs=200)
c.plot_all_inputs()
c.plot_loss()
......
......@@ -42,6 +42,7 @@ import keras.initializers
import keras.optimizers
from keras.utils.vis_utils import model_to_dot
from keras import backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
from .utils import WeightedRobustScaler, weighted_quantile, poisson_asimov_significance
......@@ -65,6 +66,25 @@ if version_info[0] > 2:
byteify = lambda input : input
def set_session_threads(n_cpu=None):
"Set the number of threads based on OMP_NUM_THREADS or the given argument"
if n_cpu is None:
if os.environ.get('OMP_NUM_THREADS'):
n_cpu = int(os.environ.get('OMP_NUM_THREADS'))
else:
return
# not sure if this is the best configuration ...
config = tf.ConfigProto(intra_op_parallelism_threads=n_cpu,
inter_op_parallelism_threads=1,
allow_soft_placement=True,
#log_device_placement=True,
device_count = {'CPU': n_cpu})
session = tf.Session(config=config)
K.set_session(session)
def load_from_dir(path):
"Load a project and the options from a directory"
try:
......@@ -939,6 +959,8 @@ class ClassificationProject(object):
self.total_epochs = self._read_info("epochs", 0)
set_session_threads()
logger.info("Train model")
if not self.balance_dataset:
try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment