Skip to content

Commit

Permalink
avoiding the error when resetting the platform
Browse files Browse the repository at this point in the history
  • Loading branch information
xiesl97 authored Sep 21, 2024
1 parent 649e3df commit ff38466
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/elisa/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ def set_jax_platform(platform: Literal['cpu', 'gpu', 'tpu'] | None = None):

jax.config.update('jax_platform_name', platform)

if platform == 'gpu':
# see https://jax.readthedocs.io/en/latest/gpu_performance_tips.html
xla_gpu_flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true'
)
xla_flags = os.getenv('XLA_FLAGS', '')
if xla_gpu_flags not in xla_flags:
os.environ['XLA_FLAGS'] = f'{xla_flags} {xla_gpu_flags}'
# if platform == 'gpu':
# # see https://jax.readthedocs.io/en/latest/gpu_performance_tips.html
# xla_gpu_flags = (
# '--xla_gpu_enable_triton_softmax_fusion=true '
# '--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
# '--xla_gpu_enable_latency_hiding_scheduler=true '
# '--xla_gpu_enable_highest_priority_async_stream=true'
# )
# xla_flags = os.getenv('XLA_FLAGS', '')
# if xla_gpu_flags not in xla_flags:
# os.environ['XLA_FLAGS'] = f'{xla_flags} {xla_gpu_flags}'


def set_cpu_cores(n: int) -> None:
Expand Down

0 comments on commit ff38466

Please sign in to comment.