Skip to content

Commit

Permalink
[Backport] Add default v5 flags (#6204)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored Dec 19, 2023
1 parent 6860e30 commit 53b13bb
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def _setup_libtpu_flags():
# and thus worse performance.
flags = _set_missing_flags(flags,
(('xla_latency_hiding_scheduler_rerun', '1'),))

if tpu.version() == 5:
default_v5_flags = {
# Enable async collectives
'xla_enable_async_all_gather': 'true',
'xla_enable_async_collective_permute': 'true',
}
flags = _set_missing_flags(flags, default_v5_flags.items())

os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)


Expand Down

0 comments on commit 53b13bb

Please sign in to comment.