Skip to content

Commit

Permalink
Add default v5 flags
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Dec 14, 2023
1 parent d2e0676 commit 7e20976
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,27 @@ def _setup_libtpu_flags():
# and thus worse performance.
flags = _set_missing_flags(flags,
(('xla_latency_hiding_scheduler_rerun', '1'),))

if tpu.version() == 5:
default_v5_flags = {
# Enable async collectives
'xla_enable_async_all_gather': 'true',
'xla_enable_async_collective_permute': 'true',
# Limit compiler-injected rematerialization
'xla_jf_rematerialization_percent_shared_memory_limit': '10000',
# Enable collective matmul
'xla_jf_spmd_threshold_for_windowed_einsum_mib': '0'
# Enable async collective fusions
'xla_tpu_enable_async_collective_fusion': 'true',
'xla_tpu_enable_async_collective_fusion_fuse_all_gather': 'true',
'xla_tpu_enable_async_collective_fusion_multiple_steps': 'true',
# Disable net router
'xla_tpu_enable_net_router_in_all_gather': 'false',
# Disable experimental Reduce+Broadcast->ReduceWindow-Conv fusion
'xla_tpu_rwb_fusion': 'false',
}
flags = _set_missing_flags(flags, default_v5_flags.items())

os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)


Expand Down

0 comments on commit 7e20976

Please sign in to comment.