Skip to content

Commit

Permalink
[SPMD] suppor DTensor API integration (pytorch#5776)
Browse files Browse the repository at this point in the history
* [SPMD] move SPMD package to torch_xla/experimental/spmd, introduce shadow xla DTensor API.

* support backward compatibility of the old imports

* Move spmd out of experimental

* Update spmd.md for distributed/spmd
  • Loading branch information
yeounoh authored and chunnienc committed Dec 14, 2023
1 parent 45669a0 commit 31447fc
Show file tree
Hide file tree
Showing 22 changed files with 1,125 additions and 830 deletions.
14 changes: 7 additions & 7 deletions docs/spmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou

### Simple Example & Sharding Aannotation API

Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L452)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.

```python
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor
Expand All @@ -46,8 +46,8 @@ import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()
Expand Down Expand Up @@ -100,11 +100,11 @@ We derive a logical mesh based on this topology to create sub-groups of devices

![alt_text](assets/mesh_spmd2.png "image_tooltip")

We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028df4da388468fa9a41b1f98ea08bfce13b4c63/torch_xla/experimental/xla_sharding.py#L16). The axes of the logical Mesh can be named. Here is an example:
We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L17). The axes of the logical Mesh can be named. Here is an example:

```python
import torch_xla.runtime as xr
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.distributed.spmd import Mesh

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
Expand All @@ -130,7 +130,7 @@ In general, SPMD programs should create a single mesh and reuse it for all shard
Mesh nicely abstracts how the physical device mesh is constructed. Users can arrange devices in any shape and order using the logical mesh. However, one can define a more performant mesh based on the physical topology, especially when it involves Data Center Network (DCN) cross slice connections. HybridMesh creates a mesh which gives good performance out of the box for such multislice environments. It accepts ici\_mesh\_shape and dcn\_mesh\_shape which denote logical mesh shapes of inner and outer network.

```python
from torch_xla.experimental.xla_sharding import HybridMesh
from torch_xla.distributed.spmd import HybridMesh

# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
Expand Down Expand Up @@ -198,7 +198,7 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i
* `XLAShardedTensor` is a `torch.Tensor` subclass and works directly with native torch ops and `module.layers`. We use `__torch_dispatch__` to send `XLAShardedTensor` to the XLA backend. PyTorch/XLA retrieves attached sharding annotations to trace the graph and invokes XLA SPMDPartitioner.
* Internally, `XLAShardedTensor` (and its global\_tensor input) is backed by `XLATensor` with a special data structure holding references to the sharded device data.
* The sharded tensor after lazy execution may be gathered and materialized back to the host as global\_tensor when requested on the host (e.g., printing the value of the global tensor.
* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L111) to return the local shards on addressable devices as <code>List[[XLAShard](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L12)]</code>.
* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L117) to return the local shards on addressable devices as <code>List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]</code>.

There is also an ongoing effort to integrate <code>XLAShardedTensor</code> into <code>DistributedTensor</code> API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)].

Expand Down
81 changes: 81 additions & 0 deletions test/spmd/test_dtensor_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import DeviceMesh, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import xla_distribute_tensor

import unittest

import test_xla_sharding_base


class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_distribute_tensor(self):
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]

for requires_grad in [True, False]:
tensor_to_shard = torch.randn(
3 * device_count,
3,
requires_grad=requires_grad,
device=xm.xla_device())
dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh,
shard_spec)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
assert len(dist_tensor.sharding_spec) > 0

global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
local_tensor = dist_tensor.local_shards[0].data
self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
if requires_grad:
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding is persisted across mark_step calls, and test if the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/spmd/test_dynamo_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met
import unittest

Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_spmd_graph_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
import test_xla_sharding_base


Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_train_spmd_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.test.test_utils as test_utils
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs

DEFAULT_KWARGS = dict(
batch_size=128,
Expand Down
4 changes: 2 additions & 2 deletions test/spmd/test_train_spmd_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
import torch_xla.utils.checkpoint as checkpoint
import torch_xla.utils.utils as xu
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.distributed.spmd import Mesh
import torch.optim as optim
from torch import nn

Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs

from torch.distributed.checkpoint.default_planner import (
create_default_local_save_plan,
Expand Down
4 changes: 2 additions & 2 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from torch import nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding_hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs

import test_xla_sharding_base

Expand Down
2 changes: 1 addition & 1 deletion test/spmd/test_xla_virtual_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
import test_xla_sharding_base


Expand Down
2 changes: 1 addition & 1 deletion test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.utils as xu
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/_internal/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:

def _spmd_find_master_ip(current_worker_ip: str) -> str:
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.spmd as xs
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
ip_int = int(ip_address(current_worker_ip))
n_dev = xr.global_runtime_device_count()
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace torch_xla {

class ShardingUtil {
public:
// This maps to `torch_xla.experimental.xla_sharding.ShardingType` enum type.
// This maps to `torch_xla.distributed.spmd.ShardingType` enum type.
enum ShardingType {
REPLICATED = 0,
MAXIMAL = 1,
Expand Down
12 changes: 12 additions & 0 deletions torch_xla/distributed/spmd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .xla_sharded_tensor import XLAShard, XLAShardedTensor
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward)
from .api import xla_distribute_tensor, xla_distribute_module

__all__ = [
"XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
"ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
"wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module",
"xla_patched_nn_linear_forward"
]
Loading

0 comments on commit 31447fc

Please sign in to comment.