Skip to content

Commit

Permalink
set default number of available devices to 4
Browse files Browse the repository at this point in the history
  • Loading branch information
xiesl97 authored Sep 21, 2024
1 parent e0452ed commit d0d7026
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/elisa/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,24 +258,23 @@ def get_parallel_number(n: int | None) -> int:
The available number of parallel processes.
"""
n_max = jax.local_device_count()
n_default = 4

if n is None:
return n_max
return n_default
else:
n = int(n)
if n <= 0:
raise ValueError(
f'number of parallel processes must be positive, got {n}'
)
raise ValueError(f"number of parallel processes must be positive, got {n}")

if n > n_max:
warnings.warn(
f'number of parallel processes ({n}) is more than the number of '
f'available devices ({jax.local_device_count()}), reset to '
f'{jax.local_device_count()}',
f"number of parallel processes ({n}) is more than the number of "
f"available devices ({n_max}), reset to "
f"{n_max}",
Warning,
)
n = jax.local_device_count()
n = n_max

return n

Expand Down

0 comments on commit d0d7026

Please sign in to comment.