diff --git a/docs/spmd.md b/docs/spmd.md
index 8b2f886880d..61afba530b0 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -33,7 +33,7 @@ Also, this version of the SPMD is currently only tested/optimized on Google Clou
### Simple Example & Sharding Annotation API
-Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
+Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L452)). This takes a `torch.Tensor` as input and returns an `XLAShardedTensor` as output.
```python
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor
@@ -46,8 +46,8 @@ import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
-import torch_xla.experimental.xla_sharding as xs
-from torch_xla.experimental.xla_sharding import Mesh
+import torch_xla.distributed.spmd as xs
+from torch_xla.distributed.spmd import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
@@ -100,11 +100,11 @@ We derive a logical mesh based on this topology to create sub-groups of devices
![alt_text](assets/mesh_spmd2.png "image_tooltip")
-We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028df4da388468fa9a41b1f98ea08bfce13b4c63/torch_xla/experimental/xla_sharding.py#L16). The axes of the logical Mesh can be named. Here is an example:
+We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L17). The axes of the logical Mesh can be named. Here is an example:
```python
import torch_xla.runtime as xr
-from torch_xla.experimental.xla_sharding import Mesh
+from torch_xla.distributed.spmd import Mesh
# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
@@ -130,7 +130,7 @@ In general, SPMD programs should create a single mesh and reuse it for all shard
Mesh nicely abstracts how the physical device mesh is constructed. Users can arrange devices in any shape and order using the logical mesh. However, one can define a more performant mesh based on the physical topology, especially when it involves Data Center Network (DCN) cross slice connections. HybridMesh creates a mesh which gives good performance out of the box for such multislice environments. It accepts ici\_mesh\_shape and dcn\_mesh\_shape which denote logical mesh shapes of inner and outer network.
```python
-from torch_xla.experimental.xla_sharding import HybridMesh
+from torch_xla.distributed.spmd import HybridMesh
# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
@@ -198,7 +198,7 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i
* `XLAShardedTensor` is a `torch.Tensor` subclass and works directly with native torch ops and `module.layers`. We use `__torch_dispatch__` to send `XLAShardedTensor` to the XLA backend. PyTorch/XLA retrieves attached sharding annotations to trace the graph and invokes XLA SPMDPartitioner.
* Internally, `XLAShardedTensor` (and its global\_tensor input) is backed by `XLATensor` with a special data structure holding references to the sharded device data.
* The sharded tensor after lazy execution may be gathered and materialized back to the host as global\_tensor when requested on the host (e.g., printing the value of the global tensor).
-* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L111) to return the local shards on addressable devices as List[[XLAShard](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L12)]
.
+* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L117) to return the local shards on addressable devices as List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]
.
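For illustration, here is a minimal sketch (assuming SPMD is enabled and that `xs`, `xm`, and a two-dimensional `mesh` are defined as in the earlier example) of annotating a tensor and then inspecting its local shards:

```python
# Minimal sketch: `mesh` spans all addressable devices, as constructed above.
t = torch.randn(16, 10).to(xm.xla_device())
sharded = xs.mark_sharding(t, mesh, (0, None))  # returns an XLAShardedTensor
xm.mark_step()  # lazy execution runs; local shard handles are available afterwards
for shard in sharded.local_shards:  # List[XLAShard]
  print(shard.shard_device, shard.data.shape, shard.replica_id)
```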
There is also an ongoing effort to integrate XLAShardedTensor
into DistributedTensor
API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)].
diff --git a/test/spmd/test_dtensor_integration.py b/test/spmd/test_dtensor_integration.py
new file mode 100644
index 00000000000..552e698d352
--- /dev/null
+++ b/test/spmd/test_dtensor_integration.py
@@ -0,0 +1,81 @@
+import os
+import sys
+
+import torch
+from torch import nn
+import torch.optim as optim
+from torch.distributed._tensor import DeviceMesh, Shard
+import torch_xla
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.spmd import xla_distribute_tensor
+
+import unittest
+
+import test_xla_sharding_base
+
+
+class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest):
+
+ @classmethod
+ def setUpClass(cls):
+ xr.use_spmd()
+ super().setUpClass()
+
+ def test_xla_distribute_tensor(self):
+ device_count = xr.global_runtime_device_count()
+ device_mesh = DeviceMesh("xla", list(range(device_count)))
+ shard_spec = [Shard(0)]
+
+ for requires_grad in [True, False]:
+ tensor_to_shard = torch.randn(
+ 3 * device_count,
+ 3,
+ requires_grad=requires_grad,
+ device=xm.xla_device())
+ dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh,
+ shard_spec)
+ # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
+ assert type(dist_tensor).__name__ == "XLAShardedTensor"
+ assert len(dist_tensor.sharding_spec) > 0
+
+ global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
+ self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
+ local_tensor = dist_tensor.local_shards[0].data
+ self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
+ if requires_grad:
+ self.assertTrue(dist_tensor.global_tensor.requires_grad)
+ self.assertTrue(dist_tensor.is_leaf)
+
+ def test_optimizer_step_with_sharding(self):
+ # Use simple linear model to test model parameter sharding
+ model = self.SimpleLinear().to(xm.xla_device())
+
+ # Running the same mark_sharding test with xla_distribute_tensor instead
+ device_count = xr.global_runtime_device_count()
+ device_mesh = DeviceMesh("xla", list(range(device_count)))
+ shard_spec = [Shard(0)]
+ xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
+ sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)
+
+ model.train()
+ optimizer = optim.SGD(model.parameters(), lr=0.1)
+ data = torch.randn(128, 128).to(xm.xla_device())
+ target = torch.zeros(128).to(xm.xla_device())
+ loss_fn = nn.CrossEntropyLoss()
+ for i in range(3):
+ optimizer.zero_grad()
+ output = model(data)
+ loss = loss_fn(output, target)
+ loss.backward()
+ optimizer.step()
+ xm.mark_step()
+ # Sharding is persisted across mark_step calls; this also verifies that the sharded
+ # computation can repeat more than once without crashing.
+ self.assertEqual(sharding_spec,
+ torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
+
+
+if __name__ == '__main__':
+ test = unittest.main()
+ sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 2874d5783bc..807a518d95b 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -6,7 +6,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met
import unittest
diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py
index 73323eddcc0..3ea2b2302b0 100644
--- a/test/spmd/test_spmd_graph_dump.py
+++ b/test/spmd/test_spmd_graph_dump.py
@@ -10,7 +10,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
import test_xla_sharding_base
diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py
index cde37989ef8..7d472da83de 100644
--- a/test/spmd/test_train_spmd_imagenet.py
+++ b/test/spmd/test_train_spmd_imagenet.py
@@ -84,7 +84,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.test.test_utils as test_utils
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
DEFAULT_KWARGS = dict(
batch_size=128,
diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py
index ad5294e5cfe..e08f361c42a 100644
--- a/test/spmd/test_train_spmd_linear_model.py
+++ b/test/spmd/test_train_spmd_linear_model.py
@@ -7,10 +7,10 @@
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
import torch_xla.utils.checkpoint as checkpoint
import torch_xla.utils.utils as xu
-from torch_xla.experimental.xla_sharding import Mesh
+from torch_xla.distributed.spmd import Mesh
import torch.optim as optim
from torch import nn
diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 1081ce0b188..910b35f324b 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -14,7 +14,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
from torch.distributed.checkpoint.default_planner import (
create_default_local_save_plan,
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1b128164a22..d41f89b5f6d 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -15,8 +15,8 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
-import torch_xla.experimental.xla_sharding as xs
-from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
+import torch_xla.distributed.spmd as xs
+from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base
import torch_xla.core.xla_env_vars as xenv
diff --git a/test/spmd/test_xla_sharding_base.py b/test/spmd/test_xla_sharding_base.py
index 54067512ce2..57cbfe2a076 100644
--- a/test/spmd/test_xla_sharding_base.py
+++ b/test/spmd/test_xla_sharding_base.py
@@ -3,7 +3,7 @@
from torch import nn
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
diff --git a/test/spmd/test_xla_sharding_hlo.py b/test/spmd/test_xla_sharding_hlo.py
index 3a39a906261..723d1c71fd3 100644
--- a/test/spmd/test_xla_sharding_hlo.py
+++ b/test/spmd/test_xla_sharding_hlo.py
@@ -9,7 +9,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
import test_xla_sharding_base
diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py
index ac304e7285d..d58797eb5ff 100644
--- a/test/spmd/test_xla_virtual_device.py
+++ b/test/spmd/test_xla_virtual_device.py
@@ -9,7 +9,7 @@
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
import test_xla_sharding_base
diff --git a/test/test_operations.py b/test/test_operations.py
index 3b6e9e0a8ef..4e8ebcedee4 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -36,7 +36,7 @@
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
import torch_xla.distributed.parallel_loader as pl
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.utils as xu
diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index ec4e1626ec3..108ca7945a3 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -301,7 +301,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
def _spmd_find_master_ip(current_worker_ip: str) -> str:
import torch_xla.runtime as xr
- import torch_xla.experimental.xla_sharding as xs
+ import torch_xla.distributed.spmd as xs
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
ip_int = int(ip_address(current_worker_ip))
n_dev = xr.global_runtime_device_count()
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index 3e600be6871..697f320f575 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -15,7 +15,7 @@ namespace torch_xla {
class ShardingUtil {
public:
- // This maps to `torch_xla.experimental.xla_sharding.ShardingType` enum type.
+ // This maps to `torch_xla.distributed.spmd.ShardingType` enum type.
enum ShardingType {
REPLICATED = 0,
MAXIMAL = 1,
diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py
new file mode 100644
index 00000000000..3cd50e1e7c0
--- /dev/null
+++ b/torch_xla/distributed/spmd/__init__.py
@@ -0,0 +1,12 @@
+from .xla_sharded_tensor import XLAShard, XLAShardedTensor
+from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
+ XLAPatchedLinear, mark_sharding, clear_sharding,
+ wrap_if_sharded, xla_patched_nn_linear_forward)
+from .api import xla_distribute_tensor, xla_distribute_module
+
+__all__ = [
+ "XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
+ "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
+ "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module",
+ "xla_patched_nn_linear_forward"
+]
diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py
new file mode 100644
index 00000000000..bea4415db57
--- /dev/null
+++ b/torch_xla/distributed/spmd/api.py
@@ -0,0 +1,182 @@
+import logging
+import os
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+
+import torch
+
+import torch.nn as nn
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.placement_types import Placement, Replicate
+
+import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
+import torch_xla.runtime as xr # type:ignore[import]
+from torch_xla.distributed.spmd import ( # type:ignore[import]
+ XLAShardedTensor, mark_sharding, Mesh, ShardingType,
+)
+
+log = logging.getLogger(__name__)
+
+
+# Wrapper that enables XLA SPMD execution mode before calling the wrapped function.
+def with_xla(func: Callable) -> Callable:
+ assert func is not None
+
+ @wraps(func) # pyre-ignore[6]
+ def wrapper(
+ self,
+ *args: Tuple[object],
+ **kwargs: Dict[str, Any] # type: ignore[misc]
+ ) -> None:
+ os.environ["XLA_USE_SPMD"] = "1"
+ return func(self, *args, **kwargs) # type: ignore[misc]
+
+ return wrapper
+
+
+@with_xla
+def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh":
+ """
+ Convert a DTensor `dt_mesh` to an XLAShardedTensor `Mesh`.
+
+ Example (1x4 logical device mesh topology):
+ ```
+ dt_mesh = DeviceMesh("xla", [[1, 2, 3, 4]])
+ dt_mesh.shape
+ >> torch.Size([1, 4])
+
+ mesh = convert_to_xla_mesh(dt_mesh)
+ mesh.mesh_shape
+ >> (1, 4)
+ ```
+ """
+ assert dt_mesh.size() == xr.global_runtime_device_count()
+ return Mesh(dt_mesh.mesh.flatten(), tuple(dt_mesh.mesh.size()),
+ dt_mesh.mesh_dim_names)
+
+
+@with_xla
+def convert_to_xla_partition_spec(
+ tensor: torch.Tensor,
+ placements: Sequence[Placement]) -> Tuple[Union[Tuple, int, None]]:
+ """
+ Convert DTensor `placements` to XLAShardedTensor `partition_spec`.
+ This supports Shard and Replicate Placement types.
+
+ Example:
+ ```
+ # Mesh partitioning, 1/4-th of the input with replicated overlaps.
+ # The first input tensor dimension is sharded across the second mesh
+ # dimension, and the rest is replicated over the first mesh dimension.
+ t = torch.randn(4, 8, 8)
+ dt_mesh = DeviceMesh("xla", torch.arange(8).reshape(2,4))
+ placements = [Replicate(), Shard(0)]
+ my_dtensor = distribute_tensor(t, dt_mesh, placements)
+
+ # `placements = [Replicate(), Shard(0)]` describes sharding per mesh dim,
+ # and this is equivalent to `partition_spec = (1, None, None)` which is
+ # sharding per input tensor dimension.
+ partition_spec = convert_to_xla_partition_spec(t, placements)
+ >> (1, None, None)
+ ```
+ """
+ # per tensor dimension sharding
+ sharding_spec = [None] * len(tensor.shape)
+ for mesh_idx, spec in enumerate(placements):
+ if spec.is_shard(): # type:ignore[truthy-function]
+ # mesh_idx to tensor_idx (spec.dim)
+ tensor_idx = spec.dim # type:ignore[attr-defined]
+ sharding_spec[tensor_idx] = mesh_idx # type:ignore[call-overload]
+ elif spec.is_replicate():
+ # spec.dim is already set to None by default
+ continue
+ else:
+ raise ValueError(f"Unsupported placement type: {type(spec).__name__}")
+ return tuple(sharding_spec) # type:ignore[return-value]
+
+
+@with_xla
+def xla_distribute_tensor(
+ tensor: torch.Tensor,
+ device_mesh: DeviceMesh,
+ placements: Optional[Sequence[Placement]] = None,
+) -> "XLAShardedTensor":
+ """
+ Distribute a torch.Tensor to the `device_mesh` according to the `placements`
+ specified. The rank of `device_mesh` and `placements` must be the same.
+
+ Args:
+ tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
+ want to shard a tensor on a dimension that is not evenly divisible by
+ the number of devices in that mesh dimension, we use `torch.chunk`
+ semantic to shard the tensor and scatter the shards.
+ device_mesh (:class:`DeviceMesh`): DeviceMesh to distribute the
+ tensor on. Unlike `distribute_tensor`, the `device_mesh` argument is
+ required by `xla_distribute_tensor`.
+ placements (List[:class:`Placement`], optional): the placements that
+ describes how to place the tensor on DeviceMesh, must have the same
+ number of elements as `device_mesh.ndim`. If not specified, we will
+ by default replicate the tensor across the `device_mesh` from the
+ first rank of each dimension of the `device_mesh`.
+
+ Returns:
+ A :class:`XLAShardedTensor` object
+
+ .. note:: We return an XLAShardedTensor with a global view and access to local shards.
+ The successive ops would be programmed as if on a single device and without calling
+ any explicit collective ops. The actual sharded computation on the sharding-annotated tensor
+ happens lazily and is transparent to the user. In the future, we will introduce
+ a new DTensor type for this kind of programming mode (single-controller) and return it instead.
+ """
+ # device_mesh is not optional in xla_distribute_tensor
+ dt_mesh = device_mesh
+ assert dt_mesh.device_type == "xla"
+
+ # convert to XLA device mesh
+ xla_mesh = convert_to_xla_mesh(dt_mesh)
+ assert xla_mesh.mesh_shape == tuple(dt_mesh.mesh.size())
+
+ # convert tensor to the corresponding device type if it's not in that device type
+ if not tensor.is_meta:
+ tensor = tensor.to(dt_mesh.device_type)
+ # set default placements to replicated if not specified
+ if placements is None:
+ placements = [Replicate() for _ in range(dt_mesh.ndim)]
+ assert len(placements) == dt_mesh.ndim, (
+ "`placements` must have the same length as `device_mesh.ndim`! "
+ f"Found placements length: {len(placements)}, and device_mesh.ndim: {dt_mesh.ndim}.")
+ # convert placements to xla partition spec
+ partition_spec = convert_to_xla_partition_spec(tensor, placements)
+ assert len(tensor.shape) == len(partition_spec), (
+ "`partition_spec` from `placements` must have the same length as the rank of `tensor`! "
+ f"Found tensor shape length: {len(tensor.shape)}, and partition_spec length: {len(partition_spec)}.")
+
+ global_tensor = tensor
+ if type(tensor).__name__ == "DTensor":
+ raise ValueError(
+ "Cannot distribute a DTensor with local tensor on xla devices."
+ "The input tensor must be global.")
+ if type(tensor).__name__ == "XLAShardedTensor":
+ sharding_type = tensor.sharding_type # type:ignore[attr-defined]
+ assert sharding_type is None or sharding_type == ShardingType.REPLICATED, (
+ "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
+ "Clear the existing sharding annotation first by calling the "
+ "torch_xla.distributed.spmd.clear_sharding API.")
+ global_tensor = tensor.global_tensor # type:ignore[attr-defined]
+ assert global_tensor is not None, "the tensor to distribute must not be None"
+
+ # Annotates sharding and returns an XLAShardedTensor
+ xla_tensor = mark_sharding(global_tensor, xla_mesh, partition_spec)
+ return xla_tensor
+
+
+@with_xla
+def xla_distribute_module(
+ module: nn.Module,
+ device_mesh: Optional[DeviceMesh] = None,
+ partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
+ input_fn: Optional[Callable[..., None]] = None,
+ output_fn: Optional[Callable[..., None]] = None,
+) -> nn.Module:
+ raise NotImplementedError
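+
+
+# Illustrative usage sketch (comments only), mirroring test/spmd/test_dtensor_integration.py.
+# It assumes SPMD mode is enabled and that all addressable devices form a 1-D mesh:
+#
+#   from torch.distributed._tensor import DeviceMesh, Shard
+#   n = xr.global_runtime_device_count()
+#   dt_mesh = DeviceMesh("xla", list(range(n)))
+#   t = torch.randn(4 * n, 8, device=xm.xla_device())
+#   xt = xla_distribute_tensor(t, dt_mesh, [Shard(0)])  # returns an XLAShardedTensor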
diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
new file mode 100644
index 00000000000..8e2e89f75f4
--- /dev/null
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -0,0 +1,167 @@
+import torch
+from torch.utils._pytree import tree_map
+import torch_xla
+
+from dataclasses import dataclass
+from typing import List, Tuple, Iterator, Union
+import contextlib
+import collections
+
+
+@dataclass
+class XLAShard:
+ # A snapshot of the shard data from the time of XLAShard creation.
+ data: torch.Tensor
+
+ # The indices of the shard into the global tensor. If the tensor is replicated
+ # across local devices, the value of `indices` is Ellipsis. Otherwise, it is a
+ # list of the index slices across each dimension.
+ # The indices do not reflect padding, since the padding does not exist on the
+ # global tensor.
+ indices: Union[type(Ellipsis), List[slice]]
+
+ # The device this shard's data originated from.
+ shard_device: str
+
+ # The replica this shard belongs to, as determined by the sharding. The
+ # replica is determined differently for each sharding type:
+ # - TILED: Since the tensor isn't replicated, replica_id is always 0.
+ # - PARTIAL: replica_id is taken from the OpSharding and is a value in
+ # the range [0, num_replica).
+ # - REPLICATED: Since the tensor is fully replicated, replica_id is the
+ # device's global ordinal.
+ replica_id: int
+
+ @property
+ def unpadded_data(self) -> torch.Tensor:
+ ''' Returns a copy of `data` with padding removed '''
+ unpadded_indices = self.indices
+ # Replicated data has Ellipsis as indices
+ if self.indices != Ellipsis:
+ unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices]
+ return self.data[unpadded_indices]
+
+ @unpadded_data.setter
+ def unpadded_data(self, t: torch.Tensor):
+ unpadded_indices = self.indices
+ if self.indices != Ellipsis:
+ unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices]
+ self.data[unpadded_indices] = t
+
+
+@contextlib.contextmanager
+def no_dispatch() -> Iterator[None]:
+ guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
+ try:
+ yield
+ finally:
+ del guard
+
+
+class XLAShardedTensor(torch.Tensor):
+ """
+ A wrapper around `torch.Tensor` with sharding annotation
+ for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
+ for IR tracing and converted to HLO graph with sharding annotations;
+ XLA SPMDPartitioner takes a pass, propagating and injecting collectives
+ to the graph before compilation.
+ """
+
+ # XLAShardedTensor behaves like an unpartitioned,
+ # combined tensor on the host machine. When the user annotates,
+ # this is simply set to the input tensor. When an XLA partitioned
+ # output tensor (or a sharding-propagated intermediate tensor) is returned
+ # as XLAShardedTensor, the backend gathers the global data across devices
+ # and materializes and sets `global_tensor` on the host; the actual device
+ # data still remains on the individual devices as sharded or replicated.
+ # Note: we should drop this reference, and force an all-gather on each access.
+ global_tensor: torch.Tensor
+ # A logical device topology, each element describes
+ # a number of devices in the corresponding axis.
+ # NOTE: we could use more specific device-rank mapping, e.g., ShardingSpec,
+ # if needed. The change shouldn't be difficult, or create another constructor.
+ mesh_shape: Tuple[int] # TODO: create a wrapper for named axes
+ # Specifies how each input rank is sharded (index to mesh_shape)
+ # or replicated (None). For example, we can shard an 8x10 tensor
+ # 4-way row-wise, and replicate column-wise.
+ # >> input = torch.randn(8, 10)
+ # >> mesh_shape = (4, 2)
+ # >> assert np.prod(mesh_shape) == len(xm.get_xla_supported_devices())
+ # >> partition_spec = (0, None)
+ # >> assert len(input.shape) == len(partition_spec)
+ partition_spec: Tuple[int, None]
+
+ __slots__ = ['global_tensor']
+
+ @staticmethod
+ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
+ # TODO(yeounoh) wrapper can take different arguments
+ r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
+ cls,
+ elem.size(),
+ strides=elem.stride(),
+ storage_offset=elem.storage_offset(),
+ dtype=elem.dtype,
+ layout=elem.layout,
+ device=elem.device,
+ requires_grad=kwargs.get("requires_grad", False))
+ r.global_tensor = elem.detach() if r.requires_grad else elem
+ return r
+
+ # Shards on the devices are materialized/available after the lazy
+ # execution of the partitioned HLO graph. Each XLAShard points
+ # to torch.Tensor. The shards represent a snapshot on CPU, detached
+ # from the global tensor. The shard data will contain any padding
+ # which results from the sharding.
+ @property
+ def local_shards(self) -> List[XLAShard]:
+ shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor)
+ replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
+ self.global_tensor)
+ zipped = zip(shards, replica_and_indices, devices)
+ return [
+ XLAShard(data, indices, dev, replica)
+ for data, (replica, indices), dev in zipped
+ ]
+
+ # Load the given list of local shards into the underlying tensor's data
+ # on the local devices.
+ def load_local_shards_(self, shards: List[XLAShard]):
+ data = [s.data for s in shards]
+ devices = [s.shard_device for s in shards]
+ torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices)
+
+ @property
+ def sharding_spec(self):
+ return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor)
+
+ @property
+ def sharding_type(self) -> 'ShardingType':
+ from torch_xla.distributed.spmd import ShardingType
+ sharding_type = torch_xla._XLAC._get_xla_sharding_type(self.global_tensor)
+ return ShardingType(sharding_type)
+
+ def __repr__(self):
+ return f"XLAShardedTensor({self.global_tensor})"
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ """
+ The dispatcher allows the unwrapped torch.Tensor to be re-dispatched to the
+ `xla` backend as XlaTensor, and the XlaTensor with an associated sharding spec
+ to be received and wrapped as XLAShardedTensor.
+ """
+
+ def unwrap(elem):
+ return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem
+
+ def wrap(elem):
+ return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem
+
+ # no_dispatch is only needed if you use enable_python_mode.
+ # It prevents infinite recursion.
+ with no_dispatch():
+ # re-dispatch to C++
+ rs = tree_map(wrap,
+ func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+ return rs
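+
+
+# Illustrative only: because `__torch_dispatch__` unwraps XLAShardedTensor inputs and
+# re-wraps torch.Tensor outputs, ordinary ops keep returning XLAShardedTensor, e.g.:
+#
+#   import torch_xla.core.xla_model as xm
+#   xt = XLAShardedTensor(torch.randn(8, 8).to(xm.xla_device()))
+#   assert isinstance(xt + 1, XLAShardedTensor)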
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
new file mode 100644
index 00000000000..d96531a5616
--- /dev/null
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -0,0 +1,642 @@
+import os
+from collections import OrderedDict, defaultdict
+from dataclasses import dataclass, field
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
+import torch_xla.runtime as xr
+
+import numpy as np
+import functools
+import itertools
+from typing import Tuple, Union, List, Sequence, Any, Optional, Set
+from enum import IntEnum
+
+
+class Mesh:
+ """Describe the logical XLA device topology mesh and the underlying resources.
+
+ Args:
+ device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped
+ to an array of shape `mesh_shape`, filling the elements using C-like index order.
+
+ mesh_shape (Tuple[int, ...]): An int tuple describing the logical topology shape
+ of the device mesh, and each element describes the number of devices in
+ the corresponding axis.
+
+ axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to the dimensions
+ of the `devices` argument. Its length should match the rank of `devices`.
+
+ Example:
+ —------------------------------
+ mesh_shape = (4, 2)
+ num_devices = len(xm.get_xla_supported_devices())
+ device_ids = np.array(range(num_devices))
+ mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+ mesh.get_logical_mesh()
+ >> array([[0, 1],
+ [2, 3],
+ [4, 5],
+ [6, 7]])
+ mesh.shape()
+ >> OrderedDict([('x', 4), ('y', 2)])
+ """
+
+ device_ids: np.ndarray
+ mesh_shape: Tuple[int, ...]
+ axis_names: Tuple[str, ...]
+
+ def __init__(self,
+ device_ids: Union[np.ndarray, List],
+ mesh_shape: Tuple[int, ...],
+ axis_names: Tuple[str, ...] = None):
+ if not isinstance(device_ids, np.ndarray):
+ device_ids = np.array(device_ids)
+ assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
+ assert axis_names is None or (len(set(axis_names)) == len(axis_names))
+ assert (len(device_ids) == np.prod(mesh_shape))
+ assert len(device_ids) == len(np.unique(device_ids))
+ self.device_ids = device_ids
+ self.mesh_shape = mesh_shape
+ self.axis_names = axis_names
+ assert all(d < self.size() for d in device_ids)
+
+ def size(self):
+ return np.prod(self.mesh_shape)
+
+ def shape(self):
+ if self.axis_names is None:
+ return OrderedDict(
+ (dim, size) for dim, size in enumerate(self.mesh_shape))
+ return OrderedDict(
+ (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
+
+ def get_logical_mesh(self):
+ return self.device_ids.reshape(self.mesh_shape)
+
+ def get_axis_name_idx(self, name: str) -> int:
+ if name not in self.axis_names:
+ return None
+ return self.axis_names.index(name)
+
+ @functools.lru_cache(maxsize=None)
+ def get_op_sharding(self,
+ partition_spec: Tuple,
+ flatten_opsharding=False) -> torch_xla._XLAC.OpSharding:
+ """
+ Return the OpSharding for the given partition spec. This is an expensive
+ operation as the mesh grows, so the value is cached for reuse.
+ """
+ partition_spec = _translate_named_partition_spec(self, partition_spec)
+ flat_specs = np.hstack([d for d in partition_spec])
+ specs = [d for d in flat_specs if d is not None]
+ assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+ f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+ assert len(specs) == len(np.unique(specs)), \
+ f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
+ tile_assignment = _get_tile_assignment(self, partition_spec)
+ if len(tile_assignment.shape) > len(partition_spec):
+ # Use partial replication for sharding a tensor over a higher-rank mesh
+ sharding_type = ShardingType.PARTIAL
+ else:
+ sharding_type = _get_sharding_type(partition_spec, self.size())
+ replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
+ group_assignment, replication_groups = _get_group_assignment(
+ sharding_type, tile_assignment, len(partition_spec), replicate_dims)
+
+ # If flatten_opsharding = True, return the flattened version of OpSharding
+ if flatten_opsharding:
+ return (tile_assignment.tolist(), group_assignment, replication_groups,
+ int(sharding_type))
+ else:
+ return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
+ group_assignment, replication_groups,
+ int(sharding_type))
+
+
+# The HybridMesh class below is inspired by jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
+
+
+class HybridMesh(Mesh):
+ """Creates a hybrid device mesh of devices connected with ICI and DCN networks.
+ The shape of logical mesh should be ordered by increasing network-intensity
+ e.g. [replica, data, model] where model has the most network communication
+ requirements.
+
+ Args:
+ ici_mesh_shape: shape of the logical mesh for inner connected devices.
+ dcn_mesh_shape: shape of logical mesh for outer connected devices.
+
+ Example:
+ # This example is assuming 2 slices of v4-8.
+ ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
+ dcn_mesh_shape = (2, 1, 1)
+
+ mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
+ print(mesh.shape())
+ >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
+ """
+ ici_mesh_shape: Tuple[int, ...]
+ dcn_mesh_shape: Tuple[int, ...]
+
+ def __init__(self,
+ *,
+ ici_mesh_shape: Tuple[int, ...],
+ dcn_mesh_shape: Tuple[int, ...] = None,
+ axis_names: Tuple[str, ...] = None):
+ if dcn_mesh_shape is None:
+ dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
+ assert len(ici_mesh_shape) == len(dcn_mesh_shape)
+ mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
+ self.device_attributes = xr.global_runtime_device_attributes()
+ self.device_attributes.sort(
+ key=lambda attr: xm.parse_xla_device(attr['name'])[1])
+
+ if 'slice_index' in self.device_attributes[0] and np.prod(
+ dcn_mesh_shape) == 1:
+ raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
+ if 'slice_index' not in self.device_attributes[0] and np.prod(
+ dcn_mesh_shape) > 1:
+ raise ValueError('Invalid dcn_mesh_shape for single slice mesh')
+ self.ici_mesh_shape = ici_mesh_shape
+ self.dcn_mesh_shape = dcn_mesh_shape
+ if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[
+ 0]: # multislice
+ mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape,
+ self.dcn_mesh_shape)
+ else:
+ mesh = self._create_device_mesh(self.ici_mesh_shape)
+ device_ids = mesh.flatten()
+ super().__init__(device_ids, mesh_shape, axis_names)
+
+ # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
+ def _get_physical_tpu_mesh(self, devices: Sequence[int]) -> np.ndarray:
+ r"""Rearrange TPU devices in a slice into a physical mesh.
+
+ Args:
+ devices: A list of device logical ordinals in a TPU slice.
+
+ Returns:
+ A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
+ v2 and v3, global_z is instead cores_per_chip (i.e., 2).
+ """
+ assert xm.xla_device_hw(xm.xla_device()) == 'TPU'
+ # coords is a 3-dims tuple representing the device in physical mesh
+ device_coords = [self.device_attributes[d]['coords'] for d in devices]
+ dims = tuple(d + 1 for d in max(device_coords))
+ out = np.empty(dims, dtype=int)
+ for coords, d in zip(device_coords, devices):
+ out[coords[0], coords[1], coords[2]] = d
+ return out
+
+ # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.
+ def _create_device_mesh_for_nd_torus(
+ self, physical_mesh: np.ndarray,
+ mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
+ """Assigns logical parallelism axes to physical axes of an N-D torus network.
+
+ Given logical parallelism axes with sizes in `mesh_shape` and devices in an
+ N-dimensional torus network represented by `physical_mesh`, maps each logical
+ axis to one or more physical axes. Prefer to map more-performance-sensitive
+ logical axes to larger numbers of physical axes to maximize the bandwidth
+ available to them. Also prefer to assign logical axes to multiple physical
+ axes of the same size (e.g., a 2D square) rather than multiple physical axes
+ of different sizes when possible.
+
+ Note that this routine will never split a physical axis over more than one
+ logical axis (which would reduce total usable bandwidth but may sometimes be
+ desired anyway). As a result, it will error out in cases where this is
+ necessary to produce a valid mapping.
+
+ Let's use a concrete example to explain the concepts and considerations.
+
+ As an example, suppose the logical mesh is [data, model], for data and model
+ parallelism respectively. Also suppose that data parallelism is less
+ performance sensitive than model parallelism. Consider a 3D TPU pod slice of
+ shape 4x4x16, represented by a physical mesh of shape (4, 4, 16).
+
+ A TPU pod slice has equal bandwidth along all axes with wraparound links, but
+ a 2D plane of size 4x4 may have faster XLA collective implementations than a
+ non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want
+ the more performance sensitive `model` axis to be mapped to the 4x4 XY plane.
+
+ Args:
+ physical_mesh: a np.ndarray of devices in the shape of the N-D torus
+ physical topology.
+ mesh_shape: shape of the logical mesh (size of the various logical
+ parallelism axes), with axes ordered by increasing network intensity.
+
+ Returns:
+ An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with
+ each logical parallelism axis mapped to one or more physical mesh axes.
+ The axis assignment (a list of length num_logical_axes, whose elements
+ are tuples representing physical axis indices).
+ """
+ # Remaining physical axes to be assigned to logical axes.
+ assignable_physical_mesh = list(physical_mesh.shape)
+ # Map each logical axis to a subset of physical axes.
+ assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
+ # Assign logical axes from highest network intensity to lowest.
+ # `mesh_shape` is assumed to be ordered by lowest network intensity first, so
+ # reverse it first.
+ # Assigns devices to 2D or 3D logical mesh.
+ for logical_axis_index, logical_axis_size in reversed(
+ list(enumerate(mesh_shape))):
+ for num_axes in range(3, 0, -1):
+ # map a combination of devices in physical axes to the logical axis.
+ axes = itertools.combinations(assignable_physical_mesh, num_axes)
+ indices = itertools.combinations(
+ range(len(assignable_physical_mesh)), num_axes)
+ for c_axes, c_indices in zip(axes, indices):
+ if np.product(c_axes) == logical_axis_size:
+ assignment[logical_axis_index] = c_indices
+ # Zero the assigned physical axes.
+ assignable_physical_mesh = [
+ 0 if i in c_indices else v
+ for i, v in enumerate(assignable_physical_mesh)
+ ]
+ break
+ if assignment[logical_axis_index]:
+ # We already found an assignment from one candidate above.
+ break
+ else:
+ # If the num_axes for loop did not break, i.e. none of the candidates work,
+ # execution falls through to this for-else clause.
+ if logical_axis_size > 1:
+ raise NotImplementedError(
+ 'Failed to find assignment for logical_axis_index'
+ f' {logical_axis_index} of size {logical_axis_size} with remaining'
+ f' assignable mesh {assignable_physical_mesh}. The size of each'
+ ' axis in your logical mesh must be equal to the product of'
+ ' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
+ ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
+ )
+ # Flatten the assignment
+ transpose: List[int] = []
+ for x in assignment:
+ for y in x:
+ transpose.append(int(y))
+ return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment
+
+ def _create_device_mesh(self,
+ mesh_shape: Sequence[int],
+ devices: Sequence[Any] = None) -> Sequence[int]:
+ """Creates a performant device mesh.
+
+ Args:
+ mesh_shape: shape of logical mesh, ordered by increasing network-intensity
+ e.g. [replica, data, mdl] where mdl has the most network communication
+ requirements.
+ devices: optionally, the devices to construct a mesh for.
+
+ Returns:
+ A np.ndarray of devices with mesh_shape as its shape.
+ """
+
+ if devices is None:
+ devices = np.arange(xr.global_runtime_device_count())
+ if np.prod(mesh_shape) != len(devices):
+ raise ValueError(
+ f'Number of devices {len(devices)} must equal the product '
+ f'of mesh_shape {mesh_shape}')
+ physical_mesh = self._get_physical_tpu_mesh(devices)
+ device_mesh, assignment = self._create_device_mesh_for_nd_torus(
+ physical_mesh, mesh_shape)
+ return device_mesh
+
+ # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
+ def _create_hybrid_device_mesh(
+ self, ici_mesh_shape: Sequence[int],
+ dcn_mesh_shape: Sequence[int]) -> Sequence[int]:
+ """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
+
+ Args:
+ ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
+ by increasing network intensity, e.g. [replica, data, mdl] where mdl has
+ the most network communication requirements.
+ dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
+ in the same order as mesh_shape.
+
+ Returns:
+ A np.ndarray of device logical ordinal with ici_mesh_shape * dcn_mesh_shape as its shape
+ that can be fed into HybridMesh for hybrid parallelism.
+ """
+ granule_dict = defaultdict(list)
+ for d, dev in enumerate(self.device_attributes):
+ granule_dict[dev['slice_index']].append(d)
+ # sorts devices based on slice_index.
+ granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
+ if np.prod(dcn_mesh_shape) != len(granules):
+ raise ValueError(
+ f'Number of slices {len(granules)} must equal the product of '
+ f'dcn_mesh_shape {dcn_mesh_shape}')
+ # creates a separate internal mesh for each slice.
+ per_granule_meshes = [
+ self._create_device_mesh(ici_mesh_shape, granule)
+ for granule in granules
+ ]
+ granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
+ blocks = np.vectorize(
+ lambda i: per_granule_meshes[i], otypes=[object])(
+ granule_mesh)
+ device_mesh = np.block(blocks.tolist())
+ return device_mesh
+
+
+class ShardingType(IntEnum):
+ # ShardingType enum ID maps to OpSharding.Type (https://shorturl.at/pvAJX)
+ REPLICATED = 0
+ MAXIMAL = 1
+ TUPLE = 2
+ TILED = 3
+ MANUAL = 4
+ PARTIAL = 5
+
+
+def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
+ num_devices: int) -> ShardingType:
+ sharding_type = ShardingType.TILED
+ if num_devices == 1:
+ sharding_type = ShardingType.MAXIMAL
+ elif all(d is None for d in partition_spec):
+ sharding_type = ShardingType.REPLICATED
+ elif any(d is None for d in partition_spec):
+ sharding_type = ShardingType.PARTIAL
+ return sharding_type
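+
+# For example, with num_devices > 1:
+#   (0, 1)       -> ShardingType.TILED      (every tensor dim maps to a mesh dim)
+#   (0, None)    -> ShardingType.PARTIAL    (sharded on one dim, replicated elsewhere)
+#   (None, None) -> ShardingType.REPLICATED
+# and any partition_spec maps to ShardingType.MAXIMAL when num_devices == 1.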
+
+
+def _get_tile_assignment(
+ mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int,
+ None]]) -> np.ndarray:
+ """
+ Permute the given mesh to create the tile assignment based on the partition
+ spec. Returns the tiling assignment as a numpy ndarray.
+
+ If the input partition_spec combines multiple logical mesh axes over a single
+ tensor axis, the resulting tiling assignment will combine the specified axes
+ into a single axis.
+ """
+ # Flatten the partition spec and ensure that it is fully specified over the
+ # mesh for permutation.
+ tiled_dims = [x for x in partition_spec if x is not None]
+ permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
+ missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
+ tile_assignment = mesh.get_logical_mesh().transpose(permutation +
+ missing_axes)
+
+ # For any tuples in the partition_spec, the grouped axes will be adjacent
+ # after the permutation. Combine these dimensions into a single axis.
+ for i, spec in enumerate(tiled_dims):
+ if isinstance(spec, tuple):
+ shape = tile_assignment.shape
+ tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) +
+ shape[i + len(spec):])
+
+ return tile_assignment
+
+
+# Produce group assignment for partial replication. Partial replication tiles
+# groups (a.k.a. sub-groups) where the shards are fully replicated within each
+# sub-group. `replication_groups` is a list of groups as lists, where each group
+# contains the participating device IDs. `group_assignment` describes the group
+# placement and the overall mesh, where each element is the group ID.
+# The tile_assignment should be the result of `_get_tile_assignment` so that all
+# tiled dimensions are in the first axes and replicated dimensions are in the
+# remaining axes.
+def _get_group_assignment(sharding_type: ShardingType,
+ tile_assignment: np.ndarray, tensor_rank: int,
+ replicate_dims: Set[int]) -> Tuple[List, List]:
+ group_assignment = list()
+ replication_groups = list()
+ if sharding_type is ShardingType.PARTIAL:
+ # Shard across groups and replicate within subgroups; replicated dims
+ # will be used to group replication devices.
+ tile_shape = tile_assignment.shape
+ # When creating the tile assignment, the mesh is permuted so that the first
+ # few axes are used for tiling.
+ tile_dims = range(tensor_rank - len(replicate_dims))
+ group_list = [tile_assignment]
+ for d in tile_dims:
+ _group_list = list()
+ for group_members in group_list:
+ _group_list += np.split(group_members, tile_shape[d], d)
+ group_list = _group_list
+ replication_groups = [group.flatten().tolist() for group in group_list]
+
+ mesh_axis = itertools.count()
+ group_tile_shape = [
+ 1 if d in replicate_dims else tile_shape[next(mesh_axis)]
+ for d in range(tensor_rank)
+ ]
+ group_assignment = np.arange(len(replication_groups)).reshape(
+ tuple(group_tile_shape)).tolist()
+ return group_assignment, replication_groups
+
+
+def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
+ _partition_spec = list()
+ for p in partition_spec:
+ if type(p) is tuple:
+ assert not any(type(x) is tuple
+ for x in p), 'Partition spec cannot contain nested tuples'
+ _partition_spec.append(_translate_named_partition_spec(mesh, p))
+ elif (p is None) or (type(p) is int):
+ _partition_spec.append(p)
+ elif type(p) is str:
+ idx = mesh.get_axis_name_idx(p)
+ if idx is None:
+ raise ValueError(f"Axis name {p} is not defined in the given mesh")
+ _partition_spec.append(idx)
+ else:
+ raise ValueError(
+ f"Spec type {type(p)} is not supported in partition spec")
+ return tuple(_partition_spec)
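+
+# For example, given a mesh with axis_names ('data', 'model'), the named spec
+# ('data', None) translates to (0, None), and (('data', 'model'), None) to ((0, 1), None).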
+
+
+@xr.requires_pjrt
+def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+ mesh: Mesh,
+ partition_spec: Tuple[Union[Tuple, int, str, None]],
+ use_dynamo_custom_op: bool = False) -> XLAShardedTensor:
+ """
+ Annotates the tensor provided with XLA partition spec. Internally,
+ it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
+ Args:
+ t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.
+
+ mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
+
+ partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or
+ `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str.
+ This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
+ When a tuple is specified, the corresponding input tensor axis will be sharded along all
+ logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
+ will impact the resulting sharding.
+ For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
+ >> input = torch.randn(8, 10)
+ >> mesh_shape = (4, 2)
+ >> partition_spec = (0, None)
+
+ use_dynamo_custom_op (bool): if set to True, calls the dynamo custom op variant of
+ mark_sharding to make it recognizable and traceable by dynamo.
+
+ Examples
+ —------------------------------
+ mesh_shape = (4, 2)
+ num_devices = xr.global_runtime_device_count()
+ device_ids = np.array(range(num_devices))
+ mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+
+ # 4-way data parallel
+ input = torch.randn(8, 32).to(xm.xla_device())
+ xs.mark_sharding(input, mesh, (0, None))
+
+ # 2-way model parallel
+ linear = nn.Linear(32, 10).to(xm.xla_device())
+ xs.mark_sharding(linear.weight, mesh, (None, 1))
+ """
+ num_devices = xr.global_runtime_device_count()
+ assert num_devices > 0, "This requires XLA supported device(s)."
+ assert mesh.size() == num_devices, \
+ f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
+ # We only allow fully specified `partition_spec` to be applicable, as opposed
+ # to filling in the unspecified replicated dims. Fully specified `partition_spec`
+ # should be of the same rank as `t`. This is to support partial replication
+ # where the group assignment may vary with different input ranks.
+ assert len(t.shape) == len(partition_spec), \
+ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
+
+ if use_dynamo_custom_op:
+ tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
+ partition_spec, flatten_opsharding=True)
+
+ if isinstance(t, XLAShardedTensor):
+ torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
+ t.global_tensor, tile_assignment, group_assignment,
+ replication_groups, sharding_type)
+ return t
+ else:
+ torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
+ t, tile_assignment, group_assignment, replication_groups,
+ sharding_type)
+ return XLAShardedTensor(t)
+ else:
+ op_sharding = mesh.get_op_sharding(partition_spec)
+
+ if isinstance(t, XLAShardedTensor):
+ torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
+ return t
+ else:
+ torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
+ return XLAShardedTensor(t)
+
+
+def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
+ """Clear sharding annotation from the input tensor and return a `cpu` casted tensor."""
+ torch_xla._XLAC._xla_clear_sharding(t)
+ if isinstance(t, XLAShardedTensor):
+ return t.global_tensor
+ return t
+
+
+def wrap_if_sharded(x: Any) -> Any:
+ """
+ If the input is a sharded tensor, return an XLAShardedTensor wrapping it.
+ Otherwise, returns the input.
+ """
+ if (isinstance(x, torch.Tensor) and not isinstance(x, XLAShardedTensor) and
+ x.device.type == 'xla' and
+ torch_xla._XLAC._get_xla_sharding_type(x) is not None):
+ return XLAShardedTensor(x)
+ return x
+
+
+@dataclass
+class ShardingSpec:
+ mesh: Mesh
+ partition_spec: Tuple[Union[int, None]]
+ minibatch: Optional[bool] = False
+
+ # Derived fields
+ _tile_assignment: List[int] = field(init=False)
+ _group_assignment: List[int] = field(init=False)
+ _replication_groups: List[int] = field(init=False)
+ _sharding_type: ShardingType = field(init=False)
+
+ @xr.requires_pjrt
+ def __post_init__(self):
+ mesh = self.mesh
+ partition_spec = _translate_named_partition_spec(mesh, self.partition_spec)
+ tile_assignment = _get_tile_assignment(mesh, partition_spec)
+ self._tile_assignment = tile_assignment.tolist()
+ self._sharding_type = _get_sharding_type(partition_spec,
+ xr.global_runtime_device_count())
+ replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
+ self._group_assignment, self._replication_groups = _get_group_assignment(
+ self._sharding_type, tile_assignment, len(partition_spec),
+ replicate_dims)
+
+ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
+ """
+ Create an XlaShardingSpec for the given tensor. If the tensor is
+ incompatible with the ShardingSpec, returns None.
+ """
+ if not self.can_apply(t):
+ return None
+ return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
+ self._group_assignment,
+ self._replication_groups,
+ int(self._sharding_type),
+ self.minibatch)
+
+ def can_apply(self, t: torch.Tensor) -> bool:
+ """
+ Test whether the ShardingSpec is compatible with the given torch.Tensor.
+ """
+ return len(t.shape) == len(self.partition_spec)
+
+ def apply(self, t: torch.Tensor):
+ # TODO(yeounoh) use virtual device interface when available.
+ assert (t.device == xm.xla_device())
+ mark_sharding(t, self.mesh, self.partition_spec)
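+
+# Illustrative usage sketch (comments only; assumes `mesh` is a Mesh over all devices):
+#
+#   spec = ShardingSpec(mesh, (0, None))   # shard dim 0, replicate dim 1
+#   t = torch.randn(16, 10).to(xm.xla_device())
+#   if spec.can_apply(t):
+#     spec.apply(t)                        # equivalent to mark_sharding(t, mesh, (0, None))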
+
+
+class XLAPatchedLinear(torch.autograd.Function):
+ """
+ A patched version of `torch.nn.functional.linear` that uses einsum instead
+ of torch.matmul, which would flatten the tensors to 2D and collapse the sharded
+ dimensions. The torch.matmul default behavior makes it very hard for the XLA compiler
+ to propagate the sharding annotation.
+
+ TODO (alanwaketan): Let's patch it on the dispatcher level.
+ """
+
+ @staticmethod
+ def forward(ctx, input, weight, bias=None):
+ # bias is an optional argument
+ ctx.save_for_backward(input, weight, bias)
+ with torch.no_grad():
+ product = torch.einsum('...n,mn->...m', input, weight)
+ if bias is None:
+ return product
+ return product + bias
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight, bias = ctx.saved_tensors
+ grad_input = grad_weight = grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
+ if ctx.needs_input_grad[1]:
+ grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
+ if bias is not None and ctx.needs_input_grad[2]:
+ grad_bias = torch.einsum('...m->m', grad_output)
+
+ return grad_input, grad_weight, grad_bias
+
+
+def xla_patched_nn_linear_forward(m, input):
+ return XLAPatchedLinear.apply(input, m.weight, m.bias)
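+
+
+# Illustrative only: one way to route an existing nn.Linear instance through the
+# patched forward, so einsum (rather than a flattening matmul) carries the sharding:
+#
+#   import types
+#   linear = torch.nn.Linear(32, 10).to(xm.xla_device())
+#   linear.forward = types.MethodType(xla_patched_nn_linear_forward, linear)
+#   out = linear(torch.randn(8, 32).to(xm.xla_device()))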
diff --git a/torch_xla/experimental/distributed_checkpoint/_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py
index b49e7419dcd..6ab2da163ac 100644
--- a/torch_xla/experimental/distributed_checkpoint/_helpers.py
+++ b/torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -5,7 +5,7 @@
import dataclasses
import torch
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
from torch.distributed.checkpoint.planner import SavePlan
from typing import (
@@ -23,7 +23,7 @@
)
from torch.distributed.checkpoint.metadata import (MetadataIndex,
STATE_DICT_TYPE)
-from torch_xla.experimental.xla_sharding import XLAShardedTensor, ShardingType
+from torch_xla.distributed.spmd import XLAShardedTensor, ShardingType
from torch.utils._pytree import tree_map
PATH_ITEM = Union[str, int]
diff --git a/torch_xla/experimental/distributed_checkpoint/planners.py b/torch_xla/experimental/distributed_checkpoint/planners.py
index fbf466ff28a..6810ddb56a3 100644
--- a/torch_xla/experimental/distributed_checkpoint/planners.py
+++ b/torch_xla/experimental/distributed_checkpoint/planners.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
import torch_xla
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
from collections import ChainMap
from torch.distributed.checkpoint.default_planner import (
@@ -34,7 +34,7 @@
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.utils._pytree import tree_map
-from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard
+from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
from torch_xla.experimental.distributed_checkpoint._helpers import (
FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor,
set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
diff --git a/torch_xla/experimental/xla_sharded_tensor.py b/torch_xla/experimental/xla_sharded_tensor.py
index 1c3eaf34916..a1f2e71d98b 100644
--- a/torch_xla/experimental/xla_sharded_tensor.py
+++ b/torch_xla/experimental/xla_sharded_tensor.py
@@ -1,167 +1,10 @@
-import torch
-from torch.utils._pytree import tree_map
-import torch_xla
+# Keep this for backward compatibility.
+# TODO(yeounoh) remove after 2.2 release.
+import warnings
-from dataclasses import dataclass
-from typing import List, Tuple, Iterator, Union
-import contextlib
-import collections
+warnings.warn(
+ "Importing from `torch_xla.experimental.xla_sharded_tensor` will be deprecated "
+ "after 2.2 release. Please use `torch_xla.distributed.spmd` "
+ "instead.", DeprecationWarning, 2)
-
-@dataclass
-class XLAShard:
- # A snapshot of the shard data from the time of XLAShard creation.
- data: torch.Tensor
-
- # The indices of the shard into the global tensor. If the tensor is replicated
- # across local devices, the value of `indices` is Ellipsis. Otherwise, it is a
- # list of the index slices across each dimension.
- # The indices do not reflect padding, since the padding does not exist on the
- # global tensor.
- indices: Union[type(Ellipsis), List[slice]]
-
- # The device this shard's data originated from.
- shard_device: str
-
- # The replica this shard belongs to, as determined by the sharding. The
- # replica is determined differently for each sharding type:
- # - TILED: Since the tensor isn't replicated, replica_id is always 0.
- # - PARTIAL: replica_id is taken from the OpSharding and is a value in
- # the range [0, num_replica).
- # - REPLICATED: Since the tensor is fully replicated, replica_id is the
- # device's global ordinal.
- replica_id: int
-
- @property
- def unpadded_data(self) -> torch.Tensor:
- ''' Returns a copy of `data` with padding removed '''
- unpadded_indices = self.indices
- # Replicated data has Ellipsis as indices
- if self.indices != Ellipsis:
- unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices]
- return self.data[unpadded_indices]
-
- @unpadded_data.setter
- def unpadded_data(self, t: torch.Tensor):
- unpadded_indices = self.indices
- if self.indices != Ellipsis:
- unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices]
- self.data[unpadded_indices] = t
-
-
-@contextlib.contextmanager
-def no_dispatch() -> Iterator[None]:
- guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
- try:
- yield
- finally:
- del guard
-
-
-class XLAShardedTensor(torch.Tensor):
- """
- A wrapper around `torch.Tensor` with sharding annotation
- for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
- for IR tracing and converted to HLO graph with sharding annotations;
- XLA SPMDPartitioner takes a pass, propagating and injecting collectives
- to the graph before compilation.
- """
-
-  # XLAShardedTensor behaves like an unpartitioned,
-  # combined tensor on the host machine. When the user annotates,
-  # this is simply set to the input tensor. When an XLA partitioned
-  # output tensor (or a sharding-propagated intermediate tensor) returns
-  # as XLAShardedTensor, the backend gathers global data across devices,
-  # materializes it and sets `global_tensor` on the host; the actual device
-  # data still remains on the individual devices as sharded or replicated.
- # Note: we should drop this reference, and force all gather on each access.
- global_tensor: torch.Tensor
- # A logical device topology, each element describes
- # a number of devices in the corresponding axis.
- # NOTE: we could use more specific device-rank mapping, e.g., ShardingSpec,
-  # if needed. The change shouldn't be difficult; alternatively, we could add another constructor.
- mesh_shape: Tuple[int] # TODO: create a wrapper for named axes
- # Specifies how each input rank is sharded (index to mesh_shape)
- # or replicated (None). For example, we can shard an 8x10 tensor
- # 4-way row-wise, and replicate column-wise.
- # >> input = torch.randn(8, 10)
- # >> mesh_shape = (4, 2)
- # >> assert np.prod(mesh_shape) == len(xm.get_xla_supported_devices())
- # >> partition_spec = (0, None)
- # >> assert len(input.shape) == len(partition_spec)
- partition_spec: Tuple[int, None]
-
- __slots__ = ['global_tensor']
-
- @staticmethod
- def __new__(cls, elem: torch.Tensor, *args, **kwargs):
- # TODO(yeounoh) wrapper can take different arguments
- r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
- cls,
- elem.size(),
- strides=elem.stride(),
- storage_offset=elem.storage_offset(),
- dtype=elem.dtype,
- layout=elem.layout,
- device=elem.device,
- requires_grad=kwargs.get("requires_grad", False))
- r.global_tensor = elem.detach() if r.requires_grad else elem
- return r
-
- # Shards on the devices are materialized/available after the lazy
- # execution of the partitioned HLO graph. Each XLAShard points
- # to torch.Tensor. The shards represent a snapshot on CPU, detached
- # from the global tensor. The shard data will contain any padding
- # which results from the sharding.
- @property
- def local_shards(self) -> List[XLAShard]:
- shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor)
- replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices(
- self.global_tensor)
- zipped = zip(shards, replica_and_indices, devices)
- return [
- XLAShard(data, indices, dev, replica)
- for data, (replica, indices), dev in zipped
- ]
-
- # Load the given list of local shards into the underlying tensor's data
- # on the local devices.
- def load_local_shards_(self, shards: List[XLAShard]):
- data = [s.data for s in shards]
- devices = [s.shard_device for s in shards]
- torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices)
-
- @property
- def sharding_spec(self):
- return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor)
-
- @property
- def sharding_type(self) -> 'ShardingType':
- from torch_xla.experimental.xla_sharding import ShardingType
- sharding_type = torch_xla._XLAC._get_xla_sharding_type(self.global_tensor)
- return ShardingType(sharding_type)
-
- def __repr__(self):
- return f"XLAShardedTensor({self.global_tensor})"
-
- @classmethod
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- """
-    The dispatcher allows the unwrapped torch.Tensor to be re-dispatched to the
- `xla` backend as XlaTensor, and the XlaTensor with an associated sharding spec
- to be received and wrapped as XLAShardedTensor.
- """
-
- def unwrap(elem):
- return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem
-
- def wrap(elem):
- return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem
-
- # no_dispatch is only needed if you use enable_python_mode.
- # It prevents infinite recursion.
- with no_dispatch():
- # re-dispatch to C++
- rs = tree_map(wrap,
- func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
- return rs
+from torch_xla.distributed.spmd.xla_sharded_tensor import *
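
The shim above keeps the old import path alive but now emits a `DeprecationWarning`. A small sketch of what callers see (the `warnings` filter setup is only there to surface the warning, which Python may otherwise suppress):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old location: still importable for now, but warns that it is going away.
    import torch_xla.experimental.xla_sharded_tensor  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New location: the same symbols, no warning.
from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
```

Note that the warning fires only on the first import of the module in a given process, since subsequent imports hit the module cache.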
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index 1b12513fc2e..7b8c5d42b57 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -1,642 +1,10 @@
-import os
-from collections import OrderedDict, defaultdict
-from dataclasses import dataclass, field
-import torch
-import torch_xla
-import torch_xla.core.xla_model as xm
-from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor, XLAShard
-import torch_xla.runtime as xr
+# Keep this for backward compatibility.
+# TODO(yeounoh) remove after 2.2 release.
+import warnings
-import numpy as np
-import functools
-import itertools
-from typing import Tuple, Union, List, Sequence, Any, Optional, Set
-from enum import IntEnum
+warnings.warn(
+ "Importing from `torch_xla.experimental.xla_sharding` will be deprecated "
+ "after 2.2 release. Please use `torch_xla.distributed.spmd` instead.",
+ DeprecationWarning, 2)
-
-class Mesh:
- """Describe the logical XLA device topology mesh and the underlying resources.
-
- Args:
- device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped
-      to a `mesh_shape` array, filling the elements using C-like index order.
-
-    mesh_shape (Tuple[int, ...]): An int tuple describing the logical topology shape
- of the device mesh, and each element describes the number of devices in
- the corresponding axis.
-
- axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to the dimensions
- of the `devices` argument. Its length should match the rank of `devices`.
-
- Example:
-    ------------------------------
- mesh_shape = (4, 2)
- num_devices = len(xm.get_xla_supported_devices())
- device_ids = np.array(range(num_devices))
- mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
- mesh.get_logical_mesh()
- >> array([[0, 1],
- [2, 3],
- [4, 5],
- [6, 7]])
- mesh.shape()
- >> OrderedDict([('x', 4), ('y', 2)])
- """
-
- device_ids: np.ndarray
- mesh_shape: Tuple[int, ...]
- axis_names: Tuple[str, ...]
-
- def __init__(self,
- device_ids: Union[np.ndarray, List],
- mesh_shape: Tuple[int, ...],
- axis_names: Tuple[str, ...] = None):
- if not isinstance(device_ids, np.ndarray):
- device_ids = np.array(device_ids)
- assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
- assert axis_names is None or (len(set(axis_names)) == len(axis_names))
- assert (len(device_ids) == np.prod(mesh_shape))
- assert len(device_ids) == len(np.unique(device_ids))
- self.device_ids = device_ids
- self.mesh_shape = mesh_shape
- self.axis_names = axis_names
- assert all(d < self.size() for d in device_ids)
-
- def size(self):
- return np.prod(self.mesh_shape)
-
- def shape(self):
- if self.axis_names is None:
- return OrderedDict(
- (dim, size) for dim, size in enumerate(self.mesh_shape))
- return OrderedDict(
- (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
-
- def get_logical_mesh(self):
- return self.device_ids.reshape(self.mesh_shape)
-
- def get_axis_name_idx(self, name: str) -> int:
- if name not in self.axis_names:
- return None
- return self.axis_names.index(name)
-
- @functools.lru_cache(maxsize=None)
- def get_op_sharding(self,
- partition_spec: Tuple,
- flatten_opsharding=False) -> torch_xla._XLAC.OpSharding:
- """
- Return the OpSharding for the given partition spec. This is an expensive
- operation as the mesh grows, so the value is cached for reuse.
- """
- partition_spec = _translate_named_partition_spec(self, partition_spec)
- flat_specs = np.hstack([d for d in partition_spec])
- specs = [d for d in flat_specs if d is not None]
- assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
- f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
- assert len(specs) == len(np.unique(specs)), \
- f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
-
- tile_assignment = _get_tile_assignment(self, partition_spec)
- if len(tile_assignment.shape) > len(partition_spec):
- # Use partial replication for sharding a tensor over a higher-rank mesh
- sharding_type = ShardingType.PARTIAL
- else:
- sharding_type = _get_sharding_type(partition_spec, self.size())
- replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
- group_assignment, replication_groups = _get_group_assignment(
- sharding_type, tile_assignment, len(partition_spec), replicate_dims)
-
- # If flatten_opsharding = True, return the flattened version of OpSharding
- if flatten_opsharding:
- return (tile_assignment.tolist(), group_assignment, replication_groups,
- int(sharding_type))
- else:
- return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
- group_assignment, replication_groups,
- int(sharding_type))
-
-
-# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
-
-
-class HybridMesh(Mesh):
- """Creates a hybrid device mesh of devices connected with ICI and DCN networks.
- The shape of logical mesh should be ordered by increasing network-intensity
-  e.g. [replica, data, model] where model has the most network communication
- requirements.
-
- Args:
- ici_mesh_shape: shape of the logical mesh for inner connected devices.
- dcn_mesh_shape: shape of logical mesh for outer connected devices.
-
- Example:
- # This example is assuming 2 slices of v4-8.
- ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
- dcn_mesh_shape = (2, 1, 1)
-
- mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
- print(mesh.shape())
- >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
- """
- ici_mesh_shape: Tuple[int, ...]
- dcn_mesh_shape: Tuple[int, ...]
-
- def __init__(self,
- *,
- ici_mesh_shape: Tuple[int, ...],
- dcn_mesh_shape: Tuple[int, ...] = None,
- axis_names: Tuple[str, ...] = None):
- if dcn_mesh_shape == None:
- dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
- assert len(ici_mesh_shape) == len(dcn_mesh_shape)
- mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
- self.device_attributes = xr.global_runtime_device_attributes()
- self.device_attributes.sort(
- key=lambda attr: xm.parse_xla_device(attr['name'])[1])
-
- if 'slice_index' in self.device_attributes[0] and np.prod(
- dcn_mesh_shape) == 1:
- raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
- if 'slice_index' not in self.device_attributes[0] and np.prod(
- dcn_mesh_shape) > 1:
- raise ValueError('Invalid dcn_mesh_shape for single slice mesh')
- self.ici_mesh_shape = ici_mesh_shape
- self.dcn_mesh_shape = dcn_mesh_shape
- if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[
- 0]: # multislice
- mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape,
- self.dcn_mesh_shape)
- else:
- mesh = self._create_device_mesh(self.ici_mesh_shape)
- device_ids = mesh.flatten()
- super().__init__(device_ids, mesh_shape, axis_names)
-
- # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
- def _get_physical_tpu_mesh(self, devices: Sequence[int]) -> np.ndarray:
- r"""Rearrange TPU devices in a slice into a physical mesh.
-
- Args:
- devices: A list of device logical ordinals in a TPU slice.
-
- Returns:
- A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
- v2 and v3, global_z is instead cores_per_chip (i.e., 2).
- """
- assert xm.xla_device_hw(xm.xla_device()) == 'TPU'
- # coords is a 3-dims tuple representing the device in physical mesh
- device_coords = [self.device_attributes[d]['coords'] for d in devices]
- dims = tuple(d + 1 for d in max(device_coords))
- out = np.empty(dims, dtype=int)
- for coords, d in zip(device_coords, devices):
- out[coords[0], coords[1], coords[2]] = d
- return out
-
- # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.
- def _create_device_mesh_for_nd_torus(
- self, physical_mesh: np.ndarray,
- mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
- """Assigns logical parallelism axes to physical axes of an N-D torus network.
-
- Given logical parallelism axes with sizes in `mesh_shape` and devices in an
- N-dimensional torus network represented by `physical_mesh`, maps each logical
- axis to one or more physical axes. Prefer to map more-performance-sensitive
- logical axes to larger numbers of physical axes to maximize the bandwidth
- available to them. Also prefer to assign logical axes to multiple physical
- axes of the same size (e.g., a 2D square) rather than multiple physical axes
- of different sizes when possible.
-
- Note that this routine will never split a physical axis over more than one
- logical axis (which would reduce total usable bandwidth but may sometimes be
- desired anyway). As a result, it will error out in cases where this is
- necessary to produce a valid mapping.
-
- Let's use a concrete example to explain the concepts and considerations.
-
- As an example, suppose the logical mesh is [data, model], for data and model
- parallelism respectively. Also suppose that data parallelism is less
- performance sensitive than model parallelism. Consider a 3D TPU pod slice of
- shape 4x4x16, represented by a physical mesh of shape (4, 4, 16).
-
- A TPU pod slice has equal bandwidth along all axes with wraparound links, but
- a 2D plane of size 4x4 may have faster XLA collective implementations than a
- non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want
- the more performance sensitive `model` axis to be mapped to the 4x4 XY plane.
-
- Args:
- physical_mesh: a np.ndarray of devices in the shape of the N-D torus
- physical topology.
- mesh_shape: shape of the logical mesh (size of the various logical
- parallelism axes), with axes ordered by increasing network intensity.
-
- Returns:
- An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with
- each logical parallelism axis mapped to one or more physical mesh axes.
- The axis assignment (a list of length num_logical_axes, whose elements
- are tuples representing physical axis indices).
- """
- # Remaining physical axes to be assigned to logical axes.
- assignable_physical_mesh = list(physical_mesh.shape)
- # Map each logical axis to a subset of physical axes.
- assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
- # Assign logical axes from highest network intensity to lowest.
-    # `mesh_shape` is assumed to be ordered by lowest network intensity first, so
-    # reverse it first.
- # Assigns devices to 2D or 3D logical mesh.
- for logical_axis_index, logical_axis_size in reversed(
- list(enumerate(mesh_shape))):
- for num_axes in range(3, 0, -1):
- # map a combination of devices in physical axes to the logical axis.
- axes = itertools.combinations(assignable_physical_mesh, num_axes)
- indices = itertools.combinations(
- range(len(assignable_physical_mesh)), num_axes)
- for c_axes, c_indices in zip(axes, indices):
- if np.product(c_axes) == logical_axis_size:
- assignment[logical_axis_index] = c_indices
- # Zero the assigned physical axes.
- assignable_physical_mesh = [
- 0 if i in c_indices else v
- for i, v in enumerate(assignable_physical_mesh)
- ]
- break
- if assignment[logical_axis_index]:
- # We already found an assignment from one candidate above.
- break
- else:
-        # If the `num_axes` for-loop did not break, i.e. none of the candidates fit,
-        # we fall through to this for-else clause.
- if logical_axis_size > 1:
- raise NotImplementedError(
- 'Failed to find assignment for logical_axis_index'
- f' {logical_axis_index} of size {logical_axis_size} with remaining'
- f' assignable mesh {assignable_physical_mesh}. The size of each'
- ' axis in your logical mesh must be equal to the product of'
- ' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
- ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
- )
- # Flatten the assignment
- transpose: List[int] = []
- for x in assignment:
- for y in x:
- transpose.append(int(y))
- return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment
-
- def _create_device_mesh(self,
- mesh_shape: Sequence[int],
- devices: Sequence[Any] = None) -> Sequence[int]:
- """Creates a performant device mesh.
-
- Args:
- mesh_shape: shape of logical mesh, ordered by increasing network-intensity
- e.g. [replica, data, mdl] where mdl has the most network communication
- requirements.
- devices: optionally, the devices to construct a mesh for.
-
- Returns:
- A np.ndarray of devices with mesh_shape as its shape.
- """
-
- if devices is None:
- devices = np.arange(xr.global_runtime_device_count())
- if np.prod(mesh_shape) != len(devices):
- raise ValueError(
- f'Number of devices {len(devices)} must equal the product '
- f'of mesh_shape {mesh_shape}')
- physical_mesh = self._get_physical_tpu_mesh(devices)
- device_mesh, assignment = self._create_device_mesh_for_nd_torus(
- physical_mesh, mesh_shape)
- return device_mesh
-
- # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
- def _create_hybrid_device_mesh(
- self, ici_mesh_shape: Sequence[int],
- dcn_mesh_shape: Sequence[int]) -> Sequence[int]:
- """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
-
- Args:
- ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
- by increasing network intensity, e.g. [replica, data, mdl] where mdl has
- the most network communication requirements.
- dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
- in the same order as mesh_shape.
-
- Returns:
- A np.ndarray of device logical ordinal with ici_mesh_shape * dcn_mesh_shape as its shape
- that can be fed into HybridMesh for hybrid parallelism.
- """
- granule_dict = defaultdict(list)
- for d, dev in enumerate(self.device_attributes):
- granule_dict[dev['slice_index']].append(d)
- # sorts devices based on slice_index.
- granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
- if np.prod(dcn_mesh_shape) != len(granules):
- raise ValueError(
- f'Number of slices {len(granules)} must equal the product of '
- f'dcn_mesh_shape {dcn_mesh_shape}')
-    # creates a separate internal mesh for each slice.
- per_granule_meshes = [
- self._create_device_mesh(ici_mesh_shape, granule)
- for granule in granules
- ]
- granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
- blocks = np.vectorize(
- lambda i: per_granule_meshes[i], otypes=[object])(
- granule_mesh)
- device_mesh = np.block(blocks.tolist())
- return device_mesh
-
-
-class ShardingType(IntEnum):
-  # ShardingType enum ID maps to OpSharding.Type (https://shorturl.at/pvAJX)
- REPLICATED = 0
- MAXIMAL = 1
- TUPLE = 2
- TILED = 3
- MANUAL = 4
- PARTIAL = 5
-
-
-def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
- num_devices: int) -> ShardingType:
- sharding_type = ShardingType.TILED
- if num_devices == 1:
- sharding_type = ShardingType.MAXIMAL
- elif all(d is None for d in partition_spec):
- sharding_type = ShardingType.REPLICATED
- elif any(d is None for d in partition_spec):
- sharding_type = ShardingType.PARTIAL
- return sharding_type
-
-
-def _get_tile_assignment(
- mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int,
- None]]) -> np.ndarray:
- """
- Permute the given mesh to create the tile assignment based on the partition
- spec. Returns the tiling assignment as a numpy ndarray.
-
- If the input partition_spec combines multiple logical mesh axes over a single
- tensor axis, the resulting tiling assignment will combine the specified axes
- into a single axis.
- """
- # Flatten the partition spec and ensure that it is fully specified over the
- # mesh for permutation.
- tiled_dims = [x for x in partition_spec if x is not None]
- permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
- missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
- tile_assignment = mesh.get_logical_mesh().transpose(permutation +
- missing_axes)
-
- # For any tuples in the partition_spec, the grouped axes will be adjacent
- # after the permutation. Combine these dimensions into a single axis.
- for i, spec in enumerate(tiled_dims):
- if isinstance(spec, tuple):
- shape = tile_assignment.shape
- tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) +
- shape[i + len(spec):])
-
- return tile_assignment
-
-
-# Produce group assignment for partial replication. Partial replication tiles
-# groups (a.k.a. sub-groups) where the shards are fully replicated within each
-# sub-group. `replication_groups` is a list of groups as lists, where each group
-# contains the participating device IDs. `group_assignment` describes the group
-# placement and the overall mesh, where each element is the group ID.
-# The tile_assignment should be the result of `_get_tile_assignment` so that all
-# tiled dimensions are in the first axes and replicated dimensions are in the
-# remaining axes.
-def _get_group_assignment(sharding_type: ShardingType,
- tile_assignment: np.ndarray, tensor_rank: int,
- replicate_dims: Set[int]) -> Tuple[List, List]:
- group_assignment = list()
- replication_groups = list()
- if sharding_type is ShardingType.PARTIAL:
- # Shard across groups and replicate within subgroups; replicated dims
- # will be used to group replication devices.
- tile_shape = tile_assignment.shape
- # When creating the tile assignment, the mesh is permuted so that the first
- # few axes are used for tiling.
- tile_dims = range(tensor_rank - len(replicate_dims))
- group_list = [tile_assignment]
- for d in tile_dims:
- _group_list = list()
- for group_members in group_list:
- _group_list += np.split(group_members, tile_shape[d], d)
- group_list = _group_list
- replication_groups = [group.flatten().tolist() for group in group_list]
-
- mesh_axis = itertools.count()
- group_tile_shape = [
- 1 if d in replicate_dims else tile_shape[next(mesh_axis)]
- for d in range(tensor_rank)
- ]
- group_assignment = np.arange(len(replication_groups)).reshape(
- tuple(group_tile_shape)).tolist()
- return group_assignment, replication_groups
-
-
-def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
- _partition_spec = list()
- for p in partition_spec:
- if type(p) is tuple:
- assert not any(type(x) is tuple
- for x in p), 'Partition spec cannot contain nested tuples'
- _partition_spec.append(_translate_named_partition_spec(mesh, p))
- elif (p is None) or (type(p) is int):
- _partition_spec.append(p)
- elif type(p) is str:
- idx = mesh.get_axis_name_idx(p)
- if idx is None:
- raise ValueError(f"Axis name {p} is not defined in the given mesh")
- _partition_spec.append(idx)
- else:
- raise ValueError(
- f"Spec type {type(p)} is not supported in partition spec")
- return tuple(_partition_spec)
-
-
-@xr.requires_pjrt
-def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
- mesh: Mesh,
- partition_spec: Tuple[Union[Tuple, int, str, None]],
- use_dynamo_custom_op: bool = False) -> XLAShardedTensor:
- """
- Annotates the tensor provided with XLA partition spec. Internally,
- it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
- Args:
- t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.
-
- mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
-
- partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or
- `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str.
- This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
- When a tuple is specified, the corresponding input tensor axis will be sharded along all
- logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
- will impact the resulting sharding.
- For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
- >> input = torch.randn(8, 10)
- >> mesh_shape = (4, 2)
- >> partition_spec = (0, None)
-
-    use_dynamo_custom_op (bool): if set to True, calls the dynamo custom op variant of mark_sharding
-      to make itself recognizable and traceable by dynamo.
-
- Examples
-  ------------------------------
- mesh_shape = (4, 2)
- num_devices = xr.global_runtime_device_count()
- device_ids = np.array(range(num_devices))
- mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-
- # 4-way data parallel
- input = torch.randn(8, 32).to(xm.xla_device())
- xs.mark_sharding(input, mesh, (0, None))
-
- # 2-way model parallel
- linear = nn.Linear(32, 10).to(xm.xla_device())
- xs.mark_sharding(linear.weight, mesh, (None, 1))
- """
- num_devices = xr.global_runtime_device_count()
- assert num_devices > 0, "This requires XLA supported device(s)."
- assert mesh.size() == num_devices, \
- f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
- # We only allow fully specified `partition_spec` to be applicable, as opposed
-  # to filling in the unspecified replicated dims. Fully specified `partition_spec`
- # should be of the same rank as `t`. This is to support partial replication
- # where the group assignment may vary with different input ranks.
- assert len(t.shape) == len(partition_spec), \
- f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
-
- if use_dynamo_custom_op:
- tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
- partition_spec, flatten_opsharding=True)
-
- if isinstance(t, XLAShardedTensor):
- torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
- t.global_tensor, tile_assignment, group_assignment,
- replication_groups, sharding_type)
- return t
- else:
- torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
- t, tile_assignment, group_assignment, replication_groups,
- sharding_type)
- return XLAShardedTensor(t)
- else:
- op_sharding = mesh.get_op_sharding(partition_spec)
-
- if isinstance(t, XLAShardedTensor):
- torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
- return t
- else:
- torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
- return XLAShardedTensor(t)
-
-
-def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
- """Clear sharding annotation from the input tensor and return a `cpu` casted tensor."""
- torch_xla._XLAC._xla_clear_sharding(t)
- if isinstance(t, XLAShardedTensor):
- return t.global_tensor
- return t
-
-
-def wrap_if_sharded(x: Any) -> Any:
- """
- If the input is a sharded tensor, return an XLAShardedTensor wrapping it.
- Otherwise, returns the input.
- """
- if (isinstance(x, torch.Tensor) and not isinstance(x, XLAShardedTensor) and
- x.device.type == 'xla' and
- torch_xla._XLAC._get_xla_sharding_type(x) is not None):
- return XLAShardedTensor(x)
- return x
-
-
-@dataclass
-class ShardingSpec:
- mesh: Mesh
- partition_spec: Tuple[Union[int, None]]
- minibatch: Optional[bool] = False
-
- # Derived fields
- _tile_assignment: List[int] = field(init=False)
- _group_assignment: List[int] = field(init=False)
- _replication_groups: List[int] = field(init=False)
- _sharding_type: ShardingType = field(init=False)
-
- @xr.requires_pjrt
- def __post_init__(self):
- mesh = self.mesh
- partition_spec = _translate_named_partition_spec(mesh, self.partition_spec)
- tile_assignment = _get_tile_assignment(mesh, partition_spec)
- self._tile_assignment = tile_assignment.tolist()
- self._sharding_type = _get_sharding_type(partition_spec,
- xr.global_runtime_device_count())
- replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
- self._group_assignment, self._replication_groups = _get_group_assignment(
- self._sharding_type, tile_assignment, len(partition_spec),
- replicate_dims)
-
- def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
- """
- Create an XlaShardingSpec for the given tensor. If the tensor is
- incompatible with the ShardingSpec, returns None.
- """
- if not self.can_apply(t):
- return None
- return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
- self._group_assignment,
- self._replication_groups,
- int(self._sharding_type),
- self.minibatch)
-
- def can_apply(self, t: torch.Tensor) -> bool:
- """
- Test whether the ShardingSpec is compatible with the given torch.Tensor.
- """
- return len(t.shape) == len(self.partition_spec)
-
- def apply(self, t: torch.Tensor):
- # TODO(yeounoh) use virtual device interface when available.
- assert (t.device == xm.xla_device())
- mark_sharding(t, self.mesh, self.partition_spec)
-
-
-class XLAPatchedLinear(torch.autograd.Function):
- """
- A patched version of `torch.nn.functional.linear` that uses einsum instead
- of torch.matmul which will flatten the tensors to 2D and collide the sharded
- dimensions. The torch.matmul default behavior makes it very hard for XLA compiler
- to propagate the sharding annotation.
-
- TODO (alanwaketan): Let's patch it on the dispatcher level.
- """
-
- @staticmethod
- def forward(ctx, input, weight, bias=None):
- # bias is an optional argument
- ctx.save_for_backward(input, weight, bias)
- with torch.no_grad():
- product = torch.einsum('...n,mn->...m', input, weight)
- if bias is None:
- return product
- return product + bias
-
- @staticmethod
- def backward(ctx, grad_output):
- input, weight, bias = ctx.saved_tensors
- grad_input = grad_weight = grad_bias = None
-
- if ctx.needs_input_grad[0]:
- grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
- if ctx.needs_input_grad[1]:
- grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
- if bias is not None and ctx.needs_input_grad[2]:
- grad_bias = torch.einsum('...m->m', grad_output)
-
- return grad_input, grad_weight, grad_bias
-
-
-def xla_patched_nn_linear_forward(m, input):
- return XLAPatchedLinear.apply(input, m.weight, m.bias)
+from torch_xla.distributed.spmd.xla_sharding import *
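
Everything removed from `torch_xla.experimental.xla_sharding` above (`Mesh`, `ShardingSpec`, `mark_sharding`, ...) is now expected to be imported from `torch_xla.distributed.spmd`. As a hedged example of the `ShardingSpec` path, the sketch below assumes the class is re-exported unchanged from the new package and pairs it with `MpDeviceLoader`'s `input_sharding` argument to shard host-to-device input transfers:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()

# 1D mesh over all attached devices; shard the batch dimension, replicate features.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
batch_spec = xs.ShardingSpec(mesh, ('data', None))

dataset = torch.utils.data.TensorDataset(torch.randn(64, 16))
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

# Inputs are annotated with `batch_spec` as they are moved to the XLA device.
device_loader = pl.MpDeviceLoader(
    loader, xm.xla_device(), input_sharding=batch_spec)
for (batch,) in device_loader:
    pass  # each batch is sharded along the 'data' mesh axis
```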