
Commit

Update spmd doc (#7096)
JackCaoG authored May 22, 2024
1 parent f336317 commit 8a1ada8
Showing 2 changed files with 27 additions and 12 deletions.
16 changes: 14 additions & 2 deletions docs/fsdpv2.md
@@ -31,8 +31,20 @@ loss = output.sum()
loss.backward()
optim.step()
```
-It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. The autowrapping
-feature will come in the future releases.
+It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. Here is an example of how to auto-wrap each `DecoderLayer`:
+```python3
+from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+# Apply FSDP sharding on each DecoderLayer.
+auto_wrap_policy = functools.partial(
+    transformer_auto_wrap_policy,
+    transformer_layer_cls={
+        decoder_only_model.DecoderLayer
+    },
+)
+model = FSDPv2(
+    model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
+```
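For context, the auto-wrap snippet above assumes a few pieces that sit outside this hunk: the `functools` import, the SPMD `mesh` passed as `mesh=mesh`, the `FSDPv2` wrapper itself, and a `decoder_only_model` module that defines `DecoderLayer`. Below is a minimal sketch of that assumed setup (not part of this commit; `decoder_only_model` stands in for your own model code):

```python3
import functools

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

# Enable the SPMD execution mode before any sharding annotations.
xr.use_spmd()

# 1D mesh over all devices; FSDPv2 shards parameters along this 'fsdp' axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ('fsdp',))
```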

## Sharding output

23 changes: 13 additions & 10 deletions docs/spmd.md
@@ -410,11 +410,11 @@ The SPMD API is general enough to express both data parallelism and model parallelism.
num_devices = xr.global_runtime_device_count()

# Assume data is 4d and 0th dimension is the batch dimension
-mesh_shape = (num_devices, 1, 1, 1)
-input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H'))
-partition_spec = range(num_devices)
+mesh_shape = (num_devices,)
+input_mesh = xs.Mesh(device_ids, mesh_shape, ('data',))
+partition_spec = ('data', None, None, None)

-# Shard the batch dimension
+# Shard the input's batch dimension along the `data` axis; the other dimensions are not sharded
xs.mark_sharding(input_tensor, input_mesh, partition_spec)
```
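The snippet above relies on setup that sits outside this hunk, in particular the `device_ids` array and the `xr`/`xs` imports. A minimal sketch of that assumed setup (not part of this commit), together with an example 4d input:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Enable SPMD mode and build the flat device id array used by the 1D mesh.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))

# Example 4d input whose batch (0th) dimension will be sharded.
input_tensor = torch.randn(16, 3, 224, 224).to(xm.xla_device())
```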

@@ -424,9 +424,9 @@ PyTorch/XLA’s MpDeviceLoader supports input batch sharding, which also loads the batches to the device in the background.
num_devices = xr.global_runtime_device_count()

# Assume data is 4d and 0th dimension is the batch dimension
-mesh_shape = (num_devices, 1, 1, 1)
-input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H'))
-partition_spec = range(num_devices)
+mesh_shape = (num_devices,)
+input_mesh = xs.Mesh(device_ids, mesh_shape, ('data',))
+partition_spec = ('data', None, None, None)

# Use MpDeviceLoader to load data in background
train_loader = pl.MpDeviceLoader(
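    # The rest of this call is collapsed in the diff view above. A hedged
    # sketch of how such a call is typically completed (assuming the
    # `input_mesh` and `partition_spec` defined earlier): pass the wrapped
    # DataLoader and the XLA device, plus an input_sharding spec so each
    # batch is sharded on its 0th dimension as it is loaded, e.g.
    #   train_loader, device,
    #   input_sharding=xs.ShardingSpec(input_mesh, partition_spec))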
@@ -444,10 +444,13 @@ PyTorch’s FSDP is data parallel + sharded model parameters at 0th dimension.

```python
for name, param in model.named_parameters():
-  shape = (num_devices,) + (1,) * (len(param.shape) - 1)
-  mesh = xs.Mesh(device_ids, shape)
-  xs.mark_sharding(param, mesh, range(len(param.shape)))
+  shape = (num_devices,)
+  mesh = xs.Mesh(device_ids, shape, ('fsdp',))
+  partition_spec = [None] * len(param.shape)
+  partition_spec[0] = 'fsdp'
+  xs.mark_sharding(param, mesh, partition_spec)
```
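As a small refinement of the loop above (not part of this commit), the mesh does not depend on the parameter, so it can be built once and reused. A sketch under the same assumptions (`device_ids`, `num_devices`, and the `xs` import from earlier):

```python
# Build the 1D 'fsdp' mesh once and reuse it for every parameter.
mesh = xs.Mesh(device_ids, (num_devices,), ('fsdp',))

for name, param in model.named_parameters():
  # Shard dimension 0 across the 'fsdp' axis and replicate the remaining dims.
  partition_spec = ('fsdp',) + (None,) * (len(param.shape) - 1)
  xs.mark_sharding(param, mesh, partition_spec)
```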
+PyTorch/XLA also provides a convenient wrapper for FSDP with SPMD; please take a look at this [user guide](https://github.com/pytorch/xla/blob/master/docs/fsdpv2.md).


### Running Resnet50 example with SPMD
