From 732a1c7f13912c63c0570db3e263b319e4950407 Mon Sep 17 00:00:00 2001
From: jonb377
Date: Mon, 5 Feb 2024 13:22:19 -0800
Subject: [PATCH] Add process group documentation for SPMD (#6469)

---
 docs/spmd.md | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/docs/spmd.md b/docs/spmd.md
index 334f31b5fac..c61ff0808a9 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -320,6 +320,28 @@ for step, data in enumerate(dataloader):
     print(f'Checkpoint taken at step {step}')
 ```
 
+### Process Groups
+To use `torch.distributed` APIs such as distributed checkpointing, a process
+group is required. In SPMD mode, the `xla` backend is not supported since the
+compiler is responsible for all collectives.
+
+Instead, a CPU process group such as `gloo` must be used. On TPUs, the `xla://`
+init_method is still supported to discover the master IP, global world size,
+and host rank. An example initialization is below:
+
+```python
+import torch.distributed as dist
+# Import to register the `xla://` init_method
+import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr
+
+xr.use_spmd()
+
+# The `xla://` init_method will automatically discover master worker IP, rank,
+# and global world size without requiring environment configuration on TPUs.
+dist.init_process_group('gloo', init_method='xla://')
+```
+
 ### Virtual Device Optimization
 
 PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined. This is to overlap the data transfer with the graph tracing time. However, because GSPMD allows the user to modify the tensor sharding _after _the tensor has been defined, we need an optimization to prevent unnecessary transfer of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique to place the tensor data on a virtual device SPMD:0 first, before uploading to the physical devices when all the sharding decisions are finalized. Every tensor data in SPMD mode is placed on a virtual device, SPMD:0. The virtual device is exposed to the user as an XLA device XLA:0 with the actual shards on physical devices, like TPU:0, TPU:1, etc.
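
As a supplementary illustration of the deferred-sharding flow that the Virtual Device Optimization context above describes, here is a minimal sketch: a tensor is moved to the XLA device first and its sharding is chosen afterwards. It assumes the `Mesh`/`mark_sharding` API from `torch_xla.distributed.spmd`, which is not part of the patch itself, and the module path may differ between PyTorch/XLA versions.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# The tensor is created and moved to the XLA device before any sharding
# decision is made. Under Virtual Device Optimization its data is staged on
# the virtual SPMD:0 device instead of being eagerly uploaded to a TPU.
t = torch.randn(8, 128).to(xm.xla_device())

# The sharding can still be specified after the tensor is defined; the data
# is only distributed to the physical devices once the decision is finalized.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))
xs.mark_sharding(t, mesh, ('data', 'model'))

# The user-visible device remains the single XLA device (e.g. xla:0), while
# the actual shards live on the physical TPU devices.
print(t.device)
```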