diff --git a/master/_modules/index.html b/master/_modules/index.html
index 9a5a00b4731..5e714bc6cce 100644
--- a/master/_modules/index.html
+++ b/master/_modules/index.html
@@ -227,7 +227,7 @@
pip3 install torch==2.3.0
-# GPU whl for python 3.10 + cuda 12.1
-pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl
+pip3 install torch==2.2.0
+pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl
-Wheels for other Python versions and CUDA versions can be found here.
GSPMD is an automatic parallelization system for common ML workloads. The XLA compiler transforms the single-device program into a partitioned one with the proper collectives, based on user-provided sharding hints. This feature lets developers write PyTorch programs as if they were running on a single large device, without any custom sharded computation ops or collective communications, and still scale.
*Figure 1. Comparison of two different execution strategies, (a) for non-SPMD and (b) for SPMD.*
To support GSPMD in PyTorch/XLA, we are introducing a new execution mode. Before GSPMD, the execution mode in PyTorch/XLA assumed multiple model replicas, each with a single core (Figure 1.a). This mode of execution, as illustrated above, suits data parallelism frameworks like the popular PyTorch Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP), but is limited in that a replica can only reside on one device core for execution. PyTorch/XLA SPMD introduces a new execution mode that assumes a single replica with multiple cores (Figure 1.b), allowing a replica to run across multiple device cores. This shift unlocks more advanced parallelism strategies for better large model training performance.
PyTorch/XLA SPMD is available on the new PJRT runtime. To enable PyTorch/XLA SPMD execution mode, the user must call the [use_spmd() API](https://github.com/pytorch/xla/blob/b8b484515a97f74e013dcf38125c44d53a41f011/torch_xla/runtime.py#L214):
import torch_xla.runtime as xr
xr.use_spmd()  # enable the SPMD execution mode
@@ -3330,8 +3328,8 @@
Simple Example & Sharding Annotation API
For a given cluster of devices, a physical mesh is a representation of the interconnect topology.
We derive a logical mesh based on this topology to create sub-groups of devices which can be used for partitioning different axes of tensors in a model.
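As a minimal sketch of constructing such a logical mesh (assuming an even number of addressable devices; the 2-way split along the second axis and the 'data'/'model' axis names are illustrative choices that match the named-mesh snippet below):

```python
import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

# All addressable device ids for this SPMD program, as a flat array.
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))

# Fold the flat device list into a 2D logical mesh: the first axis is used for
# data parallelism, the second for model parallelism (assumes an even device count).
mesh_shape = (num_devices // 2, 2)
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
```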
@@ -3375,8 +3373,8 @@
partition_spec has the same rank as the input tensor. Each dimension describes how the corresponding input tensor dimension is sharded across the device mesh (logically defined by mesh_shape). partition_spec is a tuple of device_mesh dimension indices or None. An index can be an int, or a str if the corresponding mesh dimension is named. This specifies how each input dimension is sharded (an index into mesh_shape) or replicated (None).
# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))
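Continuing that snippet, here is a hedged sketch of attaching the partition spec to a tensor via mark_sharding; the eight-device assumption (needed for the 4x2 mesh), the xs alias, and the 16x8 input tensor are illustrative, not required:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()

# Same 4x2 named mesh as above; requires 8 addressable devices.
device_ids = np.array(range(xr.global_runtime_device_count()))
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))

# Shard dim 0 across the 'data' axis and dim 1 across the 'model' axis.
t = torch.randn(16, 8).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('data', 'model'))
```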
@@ -3671,7 +3669,7 @@
SPMD Debugging Tool
We provide a shard placement visualization debug tool for PyTorch/XLA SPMD users on TPU/GPU/CPU with single-host/multi-host setups: you can use visualize_tensor_sharding to visualize a sharded tensor, or visualize_sharding to visualize a sharding string. Here are two code examples on a TPU single host (v4-8) with visualize_tensor_sharding or visualize_sharding:
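A rough sketch standing in for the first of those two examples (the import path of the debugging helpers, and the mesh and tensor shapes, are assumptions that may vary by release; the second example is analogous but passes a sharding spec string to visualize_sharding):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()

# Shard dim 0 of the tensor across all devices on the host.
device_ids = np.array(range(xr.global_runtime_device_count()))
mesh = Mesh(device_ids, (len(device_ids), 1), ('x', 'y'))

t = torch.randn(8, 4).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('x', 'y'))

# Prints a table showing which device holds each shard of `t`.
generated_table = visualize_tensor_sharding(t, use_color=False)
```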
The following example uses torch.distributed._tensor.distribute_module with auto-policy and xla:
```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
@@ -3720,7 +3718,7 @@
model = MyModule()  # nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
```

Optionally, one can set the following options/env-vars to control the behavior of the XLA-based auto-sharding pass:
Currently, gradient checkpointing needs to be applied to the module before the FSDP wrapper. Otherwise, recursively looping into child modules will end up in an infinite loop. We will fix this issue in future releases.
Example usage:
from torch_xla.distributed.fsdp import checkpoint_module
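A minimal sketch of that ordering with the XLA FSDP wrapper; MyModel and its layer sizes are hypothetical and only illustrate that checkpoint_module wraps the module before the FSDP wrapper does:

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module


class MyModel(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x)


# Gradient checkpointing is applied to the module first, then FSDP wraps the result.
model = FSDP(checkpoint_module(MyModel().to(xm.xla_device())))
```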
@@ -3971,9 +3969,9 @@
HuggingFace Llama 2 Example
What is PyTorch/XLA SPMD?