diff --git a/docs/spmd.md b/docs/spmd.md
index 5d6e554092d..0a245ec5473 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -470,6 +470,7 @@ Note that I used a batch size 4 times as large since I am running it on a TPU v4
 
 We provide a `shard placement visualization debug tool` for PyTorch/XLA SPMD users on TPU/GPU/CPU with single-host/multi-host support: you can use `visualize_tensor_sharding` to visualize a sharded tensor, or `visualize_sharding` to visualize a sharding string. Here are two code examples on a TPU single host (v4-8) with `visualize_tensor_sharding` or `visualize_sharding`:
 - Code snippet using `visualize_tensor_sharding` and the resulting visualization:
+
 ```python
 import rich
 
@@ -482,7 +483,9 @@ from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
 generated_table = visualize_tensor_sharding(t, use_color=False)
 ```
 ![alt_text](assets/spmd_debug_1.png "visualize_tensor_sharding example on TPU v4-8(single-host)")
+
 - Code snippet using `visualize_sharding` and the resulting visualization:
+
 ```python
 from torch_xla.distributed.spmd.debugging import visualize_sharding
 sharding = '{devices=[2,2]0,1,2,3}'
@@ -498,11 +501,13 @@ We are introducing a new PyTorch/XLA SPMD feature, called ``auto-sharding``, [RF
 PyTorch/XLA auto-sharding can be enabled by one of the following:
 - Setting the environment variable `XLA_SPMD_AUTO=1`
 - Calling the SPMD API at the beginning of your code:
+
 ```python
 import torch_xla.runtime as xr
 xr.use_spmd(auto=True)
 ```
 - Calling `torch.distributed._tensor.distribute_module` with `auto-policy` and `xla`:
+
 ```python
 import torch_xla.runtime as xr
 from torch.distributed._tensor import DeviceMesh, distribute_module
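
For reviewers who want to try the visualization-tool hunks above end to end, here is a rough, self-contained sketch (not part of the patch). It assumes a single v4-8 host with 4 addressable devices and uses `Mesh`/`mark_sharding` from `torch_xla.distributed.spmd`; the tensor `t`, the 2x2 mesh shape, and the axis names are illustrative only:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import (
    visualize_tensor_sharding,
    visualize_sharding,
)

# Enable SPMD mode before creating any XLA tensors.
xr.use_spmd()

# Illustrative 2x2 logical mesh over the 4 chips of a v4-8 host (assumption).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (2, 2), ('x', 'y'))

# Shard a toy tensor across both mesh axes, then render its placement table.
t = torch.randn(8, 4).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('x', 'y'))
generated_table = visualize_tensor_sharding(t, use_color=False)

# Alternatively, visualize a sharding spec string directly.
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
```

As the doc text says, `visualize_tensor_sharding` takes a live sharded tensor while `visualize_sharding` only needs the sharding string, which can be convenient when you have a logged spec but not the tensor itself.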
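
The auto-sharding hunk's `distribute_module` snippet is cut off in the diff context right after its imports. As a hedged sketch of how that path can be wired up, assuming the policy callable is `auto_policy` exported by `torch_xla.distributed.spmd` and using a toy `nn.Linear` as a stand-in for a real model:

```python
import torch.nn as nn
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy  # assumed export

# 1-D DTensor device mesh spanning every addressable XLA device.
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# distribute_module moves the module to the XLA device; with the auto
# policy, the auto-sharding pass decides how parameters are sharded.
model = nn.Linear(128, 64)  # hypothetical toy module
sharded_model = distribute_module(model, device_mesh, auto_policy)
```

Per the bullet list in the last hunk, this is an alternative to setting `XLA_SPMD_AUTO=1` or calling `xr.use_spmd(auto=True)`; only one of the three is needed.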