diff --git a/test/test_toolkit.py b/test/test_toolkit.py
index ecd3aa244cd2269ae24a3386eaf990272d0c4871..1ebd7073d6667a480cf0e74ebe763f2ffa465a40 100644
--- a/test/test_toolkit.py
+++ b/test/test_toolkit.py
@@ -7,6 +7,7 @@ from keras.layers import GRU
 
 from KerasROOTClassification import ClassificationProject, ClassificationProjectRNN
 
+
 def create_dataset(path):
 
     # create example dataset with (low-weighted) noise added
@@ -50,6 +51,7 @@ def test_ClassificationProject(tmp_path):
         layers=3,
         nodes=128,
     )
+
     c.train(epochs=200)
     c.plot_all_inputs()
     c.plot_loss()
diff --git a/toolkit.py b/toolkit.py
index eb9234407031b53cff85267e961a3c660169559c..ec074bdf0c9e06e32de108a40fbb0e9641177488 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -42,6 +42,7 @@ import keras.initializers
 import keras.optimizers
 from keras.utils.vis_utils import model_to_dot
 from keras import backend as K
+import tensorflow as tf
 import matplotlib.pyplot as plt
 
 from .utils import WeightedRobustScaler, weighted_quantile, poisson_asimov_significance
@@ -65,6 +66,25 @@ if version_info[0] > 2:
     byteify = lambda input : input
 
 
+def set_session_threads(n_cpu=None):
+    "Set the number of threads based on OMP_NUM_THREADS or the given argument"
+
+    if n_cpu is None:
+        if os.environ.get('OMP_NUM_THREADS'):
+            n_cpu = int(os.environ.get('OMP_NUM_THREADS'))
+        else:
+            return
+
+    # not sure if this is the best configuration ...
+    config = tf.ConfigProto(intra_op_parallelism_threads=n_cpu,
+                            inter_op_parallelism_threads=1,
+                            allow_soft_placement=True,
+                            #log_device_placement=True,
+                            device_count = {'CPU': n_cpu})
+    session = tf.Session(config=config)
+    K.set_session(session)
+
+
 def load_from_dir(path):
     "Load a project and the options from a directory"
     try:
@@ -939,6 +959,8 @@ class ClassificationProject(object):
 
         self.total_epochs = self._read_info("epochs", 0)
 
+        set_session_threads()
+
         logger.info("Train model")
         if not self.balance_dataset:
             try: