Update spmd.md for distributed/spmd
yeounoh committed Nov 11, 2023
1 parent 558a894 commit be48d90
Showing 18 changed files with 33 additions and 46 deletions.
8 changes: 4 additions & 4 deletions docs/spmd.md
@@ -33,7 +33,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou

### Simple Example & Sharding Annotation API

- Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
+ Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L452)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.

```python
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor
@@ -46,7 +46,7 @@ import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
- import torch_xla.distributed.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode.
@@ -100,7 +100,7 @@ We derive a logical mesh based on this topology to create sub-groups of devices

![alt_text](assets/mesh_spmd2.png "image_tooltip")

- We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028df4da388468fa9a41b1f98ea08bfce13b4c63/torch_xla/experimental/xla_sharding.py#L16). The axes of the logical Mesh can be named. Here is an example:
+ We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L17). The axes of the logical Mesh can be named. Here is an example:

```python
import torch_xla.runtime as xr
@@ -198,7 +198,7 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i
* `XLAShardedTensor` is a `torch.Tensor` subclass and works directly with native torch ops and `module.layers`. We use `__torch_dispatch__` to send `XLAShardedTensor` to the XLA backend. PyTorch/XLA retrieves attached sharding annotations to trace the graph and invokes XLA SPMDPartitioner.
* Internally, `XLAShardedTensor` (and its global\_tensor input) is backed by `XLATensor` with a special data structure holding references to the sharded device data.
* The sharded tensor after lazy execution may be gathered and materialized back to the host as global\_tensor when requested on the host (e.g., printing the value of the global tensor).
- * The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L111) to return the local shards on addressable devices as <code>List[[XLAShard](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L12)]</code>.
+ * The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L117) to return the local shards on addressable devices as <code>List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]</code>.

There is also an ongoing effort to integrate <code>XLAShardedTensor</code> into <code>DistributedTensor</code> API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)].
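Taken together, the doc changes above only move the import path. A minimal usage sketch under the new `torch_xla.distributed.spmd` namespace might look like the following; it is not part of this commit, and the mesh shape, tensor size, and `XLA_USE_SPMD` flag are illustrative assumptions based on the documentation above:

```python
import os
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

os.environ["XLA_USE_SPMD"] = "1"  # enable XLA SPMD execution mode

# Build a 2 x (n/2) logical mesh over all addressable devices.
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

# Shard dim 0 of the tensor across the first mesh axis, replicate dim 1.
t = torch.randn(8, 4).to(xm.xla_device())
sharded = xs.mark_sharding(t, mesh, (0, None))

xm.mark_step()  # lazy execution; shard handles materialize afterwards
print(sharded.global_tensor.shape, len(sharded.local_shards))
```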

2 changes: 1 addition & 1 deletion test/spmd/test_dtensor_integration.py
@@ -8,7 +8,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
- from torch_xla.experimental.spmd import xla_distribute_tensor
+ from torch_xla.distributed.spmd import xla_distribute_tensor

import unittest

2 changes: 1 addition & 1 deletion test/spmd/test_dynamo_spmd.py
@@ -6,7 +6,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met
import unittest

2 changes: 1 addition & 1 deletion test/spmd/test_spmd_graph_dump.py
@@ -10,7 +10,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
import test_xla_sharding_base


2 changes: 1 addition & 1 deletion test/spmd/test_train_spmd_imagenet.py
@@ -84,7 +84,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.test.test_utils as test_utils
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs

DEFAULT_KWARGS = dict(
batch_size=128,
4 changes: 2 additions & 2 deletions test/spmd/test_train_spmd_linear_model.py
@@ -7,10 +7,10 @@
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
import torch_xla.utils.checkpoint as checkpoint
import torch_xla.utils.utils as xu
- from torch_xla.experimental.spmd import Mesh
+ from torch_xla.distributed.spmd import Mesh
import torch.optim as optim
from torch import nn

2 changes: 1 addition & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
@@ -14,7 +14,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs

from torch.distributed.checkpoint.default_planner import (
create_default_local_save_plan,
4 changes: 2 additions & 2 deletions test/spmd/test_xla_sharding.py
@@ -15,8 +15,8 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
- import torch_xla.experimetnal.spmd.xla_sharding as xs
- from torch_xla.experimental.spmd import XLAShardedTensor
+ import torch_xla.distributed.spmd as xs
+ from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding_base.py
@@ -3,7 +3,7 @@

from torch import nn
import torch_xla.core.xla_model as xm
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding_hlo.py
@@ -9,7 +9,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs

import test_xla_sharding_base

2 changes: 1 addition & 1 deletion test/spmd/test_xla_virtual_device.py
@@ -9,7 +9,7 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
import test_xla_sharding_base


2 changes: 1 addition & 1 deletion test/test_operations.py
@@ -36,7 +36,7 @@
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
import torch_xla.distributed.parallel_loader as pl
- import torch_xla.distributed.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.utils as xu
2 changes: 1 addition & 1 deletion torch_xla/_internal/tpu.py
@@ -301,7 +301,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:

def _spmd_find_master_ip(current_worker_ip: str) -> str:
import torch_xla.runtime as xr
- import torch_xla.experimental.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
ip_int = int(ip_address(current_worker_ip))
n_dev = xr.global_runtime_device_count()
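The hunk above only renames the `xs` import, but for context, the visible lines suggest `_spmd_find_master_ip` encodes each worker's IP as an integer before placing it in a global tensor built from CPU shards. The stdlib round trip it relies on can be sketched in isolation; the address below is made up:

```python
from ipaddress import ip_address

worker_ip = "10.130.0.5"              # made-up example address
ip_int = int(ip_address(worker_ip))   # 176291845
assert str(ip_address(ip_int)) == worker_ip  # lossless decode back to a string
```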
5 changes: 3 additions & 2 deletions torch_xla/distributed/spmd/__init__.py
@@ -1,11 +1,12 @@
from .xla_sharded_tensor import XLAShard, XLAShardedTensor
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
- wrap_if_sharded)
+ wrap_if_sharded, xla_patched_nn_linear_forward)
from .api import xla_distribute_tensor, xla_distribute_module

__all__ = [
    "XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
    "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
-   "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module"
+   "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module",
+   "xla_patched_nn_linear_forward"
]
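With `xla_patched_nn_linear_forward` re-exported here, all of the public names resolve from the flat `torch_xla.distributed.spmd` namespace. A small smoke-test sketch (not part of the commit) would be:

```python
import torch_xla.distributed.spmd as xs

# Every name listed in __all__ above should be importable without touching
# the xla_sharding or api submodules directly.
for name in ("XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh",
             "ShardingType", "ShardingSpec", "XLAPatchedLinear",
             "mark_sharding", "clear_sharding", "wrap_if_sharded",
             "xla_distribute_tensor", "xla_distribute_module",
             "xla_patched_nn_linear_forward"):
  assert hasattr(xs, name), f"missing export: {name}"
```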
32 changes: 9 additions & 23 deletions torch_xla/distributed/spmd/api.py
@@ -9,21 +9,13 @@
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate

- log = logging.getLogger(__name__)
-
- TORCH_XLA_INITIALIZED = False
- try:
-   import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
-   import torch_xla.runtime as xr # type:ignore[import]
-   from torch_xla.experimental.spmd import ( # type:ignore[import]
-       XLAShardedTensor,)
-   from torch_xla.experimental.spmd import ( # type:ignore[import]
-       mark_sharding, Mesh, ShardingType,
-   )
+ import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
+ import torch_xla.runtime as xr # type:ignore[import]
+ from torch_xla.distributed.spmd import ( # type:ignore[import]
+     XLAShardedTensor, mark_sharding, Mesh, ShardingType,
+ )

-   TORCH_XLA_INITIALIZED = True
- except ImportError as e:
-   log.warning(e.msg)
+ log = logging.getLogger(__name__)


# wrapper to check xla test requirements
@@ -36,14 +28,8 @@ def wrapper(
*args: Tuple[object],
**kwargs: Dict[str, Any] # type: ignore[misc]
) -> None:
- if TORCH_XLA_INITIALIZED:
-   # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
-   os.environ["XLA_USE_SPMD"] = "1"
-   return func(self, *args, **kwargs) # type: ignore[misc]
- else:
-   raise ImportError(
-       "torch.distributed._tensor._xla API requires torch_xla package installation."
-   )
+ os.environ["XLA_USE_SPMD"] = "1"
+ return func(self, *args, **kwargs) # type: ignore[misc]

return wrapper
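With torch_xla imported unconditionally, the decorator collapses to a plain wrapper that flips the SPMD flag and delegates. A sketch of the resulting shape is below; the outer name `with_xla` and the `functools.wraps` detail are assumptions for illustration, and only the wrapper body comes from the diff:

```python
import os
from functools import wraps
from typing import Any, Callable, Dict, Tuple


def with_xla(func: Callable[..., None]) -> Callable[..., None]:  # assumed name

  @wraps(func)
  def wrapper(
      self,
      *args: Tuple[object],
      **kwargs: Dict[str, Any]  # type: ignore[misc]
  ) -> None:
    # The original keeps a TODO to replace the env flag with xr.use_spmd()
    # once the flag is deprecated.
    os.environ["XLA_USE_SPMD"] = "1"
    return func(self, *args, **kwargs)  # type: ignore[misc]

  return wrapper
```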

@@ -176,7 +162,7 @@ def xla_distribute_tensor(
assert (
sharding_type is None or sharding_type == ShardingType.REPLICATED
), "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
- "Clear the existing sharding annotation first, by callling torch_xla.experimental.spmd.clear_sharding API."
+ "Clear the existing sharding annotation first, by callling torch_xla.distributed.spmd.clear_sharding API."
global_tensor = tensor.global_tensor # type:ignore[attr-defined]
assert global_tensor is not None, "distributing a tensor should not be None"
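The assertion above rejects tensors that already carry a non-replicated annotation, and the error message points at `clear_sharding`. A hedged sketch of that recovery path follows; the `reshard` helper is hypothetical, and `clear_sharding` returning the unwrapped tensor is an assumption:

```python
import torch_xla.distributed.spmd as xs


def reshard(t, mesh, partition_spec):
  # Drop whatever annotation `t` currently carries, then apply the new spec.
  t = xs.clear_sharding(t)
  return xs.mark_sharding(t, mesh, partition_spec)
```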

2 changes: 1 addition & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
@@ -4,7 +4,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
- from torch_xla.experimental.spmd import XLAShardedTensor, XLAShard
+ from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
import torch_xla.runtime as xr

import numpy as np
2 changes: 1 addition & 1 deletion torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -5,7 +5,7 @@
import dataclasses

import torch
- import torch_xla.distributed.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs

from torch.distributed.checkpoint.planner import SavePlan
from typing import (
2 changes: 1 addition & 1 deletion torch_xla/experimental/distributed_checkpoint/planners.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
import torch_xla
- import torch_xla.distributed.spmd.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs

from collections import ChainMap
from torch.distributed.checkpoint.default_planner import (