diff --git a/pyproject.toml b/pyproject.toml index a33da0c..965657d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "iminuit", "jax>=0.4.28", "jaxlib>=0.4.28", - "jaxns==2.6.3", + "jaxns==2.6.7", "matplotlib", "nautilus-sampler", "numpy", diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index f802eaa..13be9d6 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from typing import Any, Callable, Literal + from jaxlib.xla_client import Device from prettytable import PrettyTable from elisa.infer.likelihood import Statistic @@ -644,7 +645,7 @@ def jaxns( s: int | None = None, k: int | None = None, c: int | None = None, - num_parallel_workers: int = 1, + devices: list[Device] | None = None, difficult_model: bool = False, parameter_estimation: bool = False, verbose: bool = False, @@ -671,8 +672,8 @@ def jaxns( Number of parallel Markov chains. The default is 30 * `D`, where `D` is the dimension of model parameters. It takes effect only for num_live_points=None. - num_parallel_workers : int, optional - Parallel workers number. The default is 1. + devices : list, optional + Devices to use. Defaults to all available devices. difficult_model : bool, optional If True, uses more robust default settings (`s` = 10 and `c` = 50 * `D`). It takes effect only for `num_live_points` = None, @@ -699,15 +700,13 @@ def jaxns( .. [1] `Phantom-Powered Nested Sampling `__ .. [2] `JAXNS API doc `__ """ - num_parallel_workers = int(num_parallel_workers) - constructor_kwargs = { 'max_samples': max_samples, 'num_live_points': num_live_points, 's': s, 'k': k, 'c': c, - 'num_parallel_workers': num_parallel_workers, + 'devices': devices, 'difficult_model': difficult_model, 'parameter_estimation': parameter_estimation, 'verbose': verbose, diff --git a/src/elisa/infer/nested_sampling.py b/src/elisa/infer/nested_sampling.py index d6579e5..0e590c4 100644 --- a/src/elisa/infer/nested_sampling.py +++ b/src/elisa/infer/nested_sampling.py @@ -247,7 +247,7 @@ def prior_model(): default_constructor_kwargs = dict( num_live_points=model.U_ndims * 25, - num_parallel_workers=1, + devices=jax.devices(), max_samples=1e4, ) default_termination_kwargs = dict(dlogZ=1e-4) @@ -272,8 +272,8 @@ def prior_model(): ) # TODO: check if this is necessary - # jit when num_parallel_workers is 1 - if self.constructor_kwargs['num_parallel_workers'] == 1: + # jit when running on single device + if len(default_ns.nested_sampler.devices) == 1: run_default_ns = jax.jit(default_ns) else: run_default_ns = default_ns