diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 7191b5d5bb92..d753f8f7c8f2 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -42,6 +42,15 @@ def _setup_libtpu_flags(): # and thus worse performance. flags = _set_missing_flags(flags, (('xla_latency_hiding_scheduler_rerun', '1'),)) + + if tpu.version() == 5: + default_v5_flags = { + # Enable async collectives + 'xla_enable_async_all_gather': 'true', + 'xla_enable_async_collective_permute': 'true', + } + flags = _set_missing_flags(flags, default_v5_flags.items()) + os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)