From 63ea76cdf0891d3fc7865757c18f270819cc40f8 Mon Sep 17 00:00:00 2001
From: Yeounoh Chung
Date: Fri, 27 Oct 2023 15:10:06 -0700
Subject: [PATCH] Add information about on-going DTensor API in spmd.md (#5735)

---
 docs/spmd.md | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/docs/spmd.md b/docs/spmd.md
index 8337bbd5af9..8b2f886880d 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -202,6 +202,20 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i
 There is also an ongoing effort to integrate XLAShardedTensor into DistributedTensor API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)].
+### DTensor Integration
+PyTorch has prototype-released [DTensor](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md) in 2.1.
+We are integrating PyTorch/XLA SPMD into DTensor API [RFC](https://github.com/pytorch/pytorch/issues/92909). We have a proof-of-concept integration for `distribute_tensor`, which calls `mark_sharding` annotation API to shard a tensor and its computation using XLA:
+```python
+import torch
+from torch.distributed import DeviceMesh, Shard, distribute_tensor
+
+# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
+mesh = DeviceMesh("xla", list(range(world_size)))
+big_tensor = torch.randn(100000, 88)
+my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
+```
+
+This feature is experimental and stay tuned for more updates, examples and tutorials in the upcoming releases.
 ### Sharding-Aware Host-to-Device Data Loading