Skip to content

Commit

Permalink
Suppress the new warning about fork/threading when in a single_threaded context
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Oct 15, 2023
1 parent e1a7099 commit 88237ad
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 55 deletions.
5 changes: 3 additions & 2 deletions galsim/config/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import inspect
from multiprocessing.managers import ListProxy, DictProxy

from .util import LoggerWrapper, SetDefaultExt, RetryIO, SafeManager
from .util import LoggerWrapper, SetDefaultExt, RetryIO, SafeManager, single_threaded
from .value import ParseValue
from .image import GetNObjForImage
from ..utilities import ensure_dir
Expand Down Expand Up @@ -70,7 +70,8 @@ class OutputManager(SafeManager): pass
OutputManager.register('list', list, ListProxy)
# Start up the output_manager
config['output_manager'] = OutputManager()
config['output_manager'].start()
with single_threaded():
config['output_manager'].start()

if 'extra_builder' not in config:
config['extra_builder'] = {}
Expand Down
10 changes: 8 additions & 2 deletions galsim/config/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from multiprocessing.managers import NamespaceProxy

from .util import LoggerWrapper, RemoveCurrent, GetRNG, GetLoggerProxy, get_cls_params
from .util import SafeManager, GetIndex, PropagateIndexKeyRNGNum
from .util import SafeManager, GetIndex, PropagateIndexKeyRNGNum, single_threaded
from .value import ParseValue, CheckAllParams, GetAllParams, SetDefaultIndex, _GetBoolValue
from .value import RegisterValueType
from ..errors import GalSimConfigError, GalSimConfigValueError, GalSimError
Expand Down Expand Up @@ -153,7 +153,13 @@ class InputManager(SafeManager): pass
InputManager.register(tag, init_func, proxy)
# Start up the input_manager
config['_input_manager'] = InputManager()
config['_input_manager'].start()
with single_threaded():
# Starting in python 3.12, there is a deprecation warning about using fork when
# a process is multithreaded. This can get triggered here by the start()
# function, even though I'm pretty sure this is completely safe.
# So at least until it is shown that this is a problem, just suppress
# this warning here by wrapping in single_threaded()
config['_input_manager'].start()

# Read all input fields provided and create the corresponding object
# with the parameters given in the config file.
Expand Down
3 changes: 2 additions & 1 deletion galsim/config/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ class LoggerManager(SafeManager): pass
logger_generator = SimpleGenerator(logger)
LoggerManager.register('logger', callable = logger_generator)
logger_manager = LoggerManager()
logger_manager.start()
with single_threaded():
logger_manager.start()
logger_proxy = logger_manager.logger()
else:
logger_proxy = None
Expand Down
50 changes: 28 additions & 22 deletions galsim/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from contextlib import contextmanager
import weakref
import os
import warnings
import numpy as np
import pstats
import math
Expand Down Expand Up @@ -1492,7 +1493,8 @@ def get_omp_threads():
os.environ[var] = "false"
return _galsim.GetOMPThreads()

@contextmanager
def single_threaded(*, num_threads=1):
    """A context manager that turns off (or down) OpenMP threading, e.g. during multiprocessing.

    Usage::

        with single_threaded():
            # ... do multiprocessing work here ...

    On entry the number of OpenMP threads is set to ``num_threads`` (and, if
    threadpoolctl is installed, BLAS-level thread pools are limited as well).
    On exit -- even if the body raises -- the original thread count is restored
    and any threadpoolctl limit is unregistered.

    While the context is active, ``DeprecationWarning`` is suppressed.  Starting
    in Python 3.12 there is a deprecation warning about using fork when a
    process is multithreaded; it triggers even for processes that are currently
    single-threaded but used multi-threading previously, so in an explicitly
    single-threaded context we silence it.

    Parameters:
        num_threads:    The number of threads you want during the context [default: 1]
    """
    # Get the current number of threads here, so we can set it back when we're done.
    orig_num_threads = get_omp_threads()

    # If threadpoolctl is installed, use that too, since it will set blas libraries to
    # be single threaded too.  This makes it so you don't need to set the environment
    # variables OPENBLAS_NUM_THREADS=1 or MKL_NUM_THREADS=1, etc.
    try:
        import threadpoolctl
    except ImportError:
        tpl = None
    else:  # pragma: no cover  (Not installed on GHA currently.)
        tpl = threadpoolctl.threadpool_limits(num_threads)

    set_omp_threads(num_threads)
    try:
        with warnings.catch_warnings():
            # Starting in python 3.12, there is a deprecation warning about using fork when
            # a process is multithreaded.  Unfortunately, this applies even to processes that
            # are currently single threaded, but used multi-threading previously.
            # So if a user is doing something in an explicitly single-threaded context,
            # suppress this DeprecationWarning.
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            yield
    finally:
        # Bug fix: in a generator-based context manager, cleanup after the yield
        # must be inside a finally clause; otherwise an exception raised in the
        # with-block body would skip it, leaving the thread count at num_threads
        # and the threadpoolctl limit still registered.
        set_omp_threads(orig_num_threads)
        if tpl is not None:  # pragma: no cover
            tpl.unregister()



Expand Down
57 changes: 29 additions & 28 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,34 +1859,35 @@ def worker(input, output):

nproc = 4 # Each process will do 4 lists (typically)

# First make lists in the single process:
ref_lists = dict()
for seed in seeds:
list = generate_list(seed)
ref_lists[seed] = list

# Now do this with multiprocessing
# Put the seeds in a queue
task_queue = Queue()
for seed in seeds:
task_queue.put( [seed] )

# Run the tasks:
done_queue = Queue()
for k in range(nproc):
Process(target=worker, args=(task_queue, done_queue)).start()

# Check the results in the order they finished
for i in range(len(seeds)):
list, proc, args = done_queue.get()
seed = args[0]
np.testing.assert_array_equal(
list, ref_lists[seed],
err_msg="Random numbers are different when using multiprocessing")

# Stop the processes:
for k in range(nproc):
task_queue.put('STOP')
with single_threaded():
# First make lists in the single process:
ref_lists = dict()
for seed in seeds:
list = generate_list(seed)
ref_lists[seed] = list

# Now do this with multiprocessing
# Put the seeds in a queue
task_queue = Queue()
for seed in seeds:
task_queue.put( [seed] )

# Run the tasks:
done_queue = Queue()
for k in range(nproc):
Process(target=worker, args=(task_queue, done_queue)).start()

# Check the results in the order they finished
for i in range(len(seeds)):
list, proc, args = done_queue.get()
seed = args[0]
np.testing.assert_array_equal(
list, ref_lists[seed],
err_msg="Random numbers are different when using multiprocessing")

# Stop the processes:
for k in range(nproc):
task_queue.put('STOP')


@timer
Expand Down

0 comments on commit 88237ad

Please sign in to comment.