Update spmd.md with SPMD debug tool (#6358)
ManfeiBai authored Jan 23, 2024
1 parent 07832b0 commit bc2ebed
Showing 3 changed files with 28 additions and 0 deletions.
Binary file added docs/assets/spmd_debug_1.png
Binary file added docs/assets/spmd_debug_2.png
28 changes: 28 additions & 0 deletions docs/spmd.md
@@ -351,3 +351,31 @@ XLA_USE_SPMD=1 python test/spmd/test_train_spmd_imagenet.py --fake_data --batch_
```

Note that I used a batch size 4 times as large since I am running it on a TPU v4, which has 4 TPU devices attached to it. You should see the throughput become roughly 4x that of the non-SPMD run.

### SPMD Debugging Tool

We provide a shard placement visualization debug tool for PyTorch/XLA SPMD users on TPU/GPU/CPU, for both single-host and multi-host setups: use `visualize_tensor_sharding` to visualize a sharded tensor, or use `visualize_sharding` to visualize a sharding string. Here are two code examples on a single-host TPU (v4-8), one with `visualize_tensor_sharding` and one with `visualize_sharding`:
- Code snippet using `visualize_tensor_sharding` and its visualization result:
```python
import rich
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the XLA SPMD execution mode

# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
# (assumes a single host with 4 devices, e.g. a TPU v4-8)
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (2, 2), ('x', 'y'))

t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
```
![alt_text](assets/spmd_debug_1.png "visualize_tensor_sharding example on TPU v4-8 (single-host)")
- Code snippet using `visualize_sharding` and its visualization result:
```python
from torch_xla.distributed.spmd.debugging import visualize_sharding
# An HLO sharding string: the tensor is tiled over a [2,2] grid of devices 0,1,2,3
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
```
![alt_text](assets/spmd_debug_2.png "visualize_sharding example on TPU v4-8 (single-host)")

You can use these examples on single-host TPU/GPU/CPU and modify them to run on multi-host environments. You can also adapt them to the `tiled`, `partial_replication`, and `replicated` sharding styles; a minimal sketch of the latter two follows below.
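For illustration, here is a minimal sketch of how the other sharding styles could be visualized, assuming the same 2x2 `mesh`, imports, and `visualize_tensor_sharding` from the first snippet above; the partition specs shown are illustrative, not the only valid choices:

```python
# Reuses `torch`, `xs`, `mesh`, and `visualize_tensor_sharding` from the snippet above.

# Partial replication: shard dim 0 over the 'x' axis, replicate over the 'y' axis.
t_partial = torch.randn(8, 4, device='xla')
xs.mark_sharding(t_partial, mesh, ('x', None))
visualize_tensor_sharding(t_partial, use_color=False)

# Full replication: no dimension is sharded, every device holds a full copy.
t_replicated = torch.randn(8, 4, device='xla')
xs.mark_sharding(t_replicated, mesh, (None, None))
visualize_tensor_sharding(t_replicated, use_color=False)
```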
