diff --git a/README.md b/README.md
index fe5bf8c..e23cf93 100644
--- a/README.md
+++ b/README.md
@@ -1,63 +1,55 @@
-# veScale: A PyTorch Native LLM Training Framework
+
+
+
-## Coming Soon
+# A PyTorch Native LLM Training Framework
-We are refactoring our [internal LLM training system](https://arxiv.org/abs/2402.15627) components to meet open source standard. The tentative timeline is as follows:
+_**An Industrial-Level Framework for Easy-of-Use**_
-1. by mid April, 4D parallelism (tensor parallelism, sequence parallelism, data parallelism and ZERO) examples for nanoGPT, Llama2 and Mixtral models
-2. by end of May, fast checkpointing system
-3. by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training
+- 🔥 **PyTorch Native**: veScale is rooted in PyTorch-native data structures, operators, and APIs, enjoying the ecosystem of PyTorch that dominates the ML world.
-## Installation
+- 🛡 **Zero Model Code Change**: veScale decouples distributed system design from model architecture, requiring near-zero or zero modification on the model code of users.
-### From Source
+- 🚀 **Single Device Abstraction**: veScale provides single-device semantics to users, automatically distributing and orchestrating model execution in a cluster of devices.
-#### Install a Patched Version of PyTorch
+- 🎯 **Automatic Parallelism Planning**: veScale parallelizes model execution with a synergy of strategies (tensor, sequence, data, ZeRO, pipeline parallelism) under semi- or full-automation [coming soon].
-```bash
-bash patches/build_pytorch_w_patch.sh
-```
+- ⚡ **Eager & Compile Mode**: veScale supports not only Eager-mode automation for parallel training and inference but also Compile-mode for ultimate performance [coming soon].
-This will compile and install a patched version of PyTorch (based on v2.2.1_rc3).
-The patch code can be found here: [PyTorch-Patch](patches/patched_pytorch_v2.2.1_rc3.patch)
+- đź“€ **Automatic Checkpoint Resharding**: veScale manages distributed checkpoints automatically with online resharding across different cluster sizes and different parallelism strategies.
-#### Install a Patched Version of TorchDistX
-```bash
-bash patches/build_torchdistX_w_patch.sh
-```
+## Coming Soon
-This will compile and install a patched version of TorchdistX (based on its master).
-The patch code can be found here: [TorchDistX-Patch](patches/patched_torchdistX_9c1b9f.patch)
+_**veScale**_ is still in its early phase. We are refactoring our [internal LLM training system](https://arxiv.org/abs/2402.15627) components to meet open source standard. The tentative timeline is as follows:
-#### Install veScale
+- by end of May, fast checkpointing system
-```bash
-pushd python && pip3 install -r requirements.txt && pip3 install -e . && popd
-```
+- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training
-This will install veScale and its dependencies.
+## Table of Content ([web view](https://volcengine.github.io/veScaleWeb/))
-### Docker Image
+**[Introduction](./docs/texts/introduction.md)**
-#### Build the Docker Image
+**[Quick Start](./docs/texts/quick-start.md)**
-Make sure it is in the Vescale directory.
+**[DTensor](./vescale/dtensor/README.md)**
-```bash
-docker build .
-```
-It may take a while to build the image.
+**Parallel**
+ * [Overview](./docs/texts/parallel_overview.md)
+ * [Tensor Parallel & Sequence Parallel](./vescale/dmodule/README.md)
+ * [Data Parallel](./vescale/ddp/README.md)
+ * [Optimizer Parallel](./vescale/optim/README.md)
+ * [Pipeline Parallel](./vescale/pipe/README.md)
+ * [nD Device Mesh](./vescale/devicemesh_api/README.md)
-Once the building process is finished, you can `docker run` with the id.
+**Plan**
+ * [Auto TP & SP Plan](./vescale/dmp/README.md)
+**[Checkpoint](./vescale/checkpoint/README.md)**
+## [We Are Hiring!](https://volcengine.github.io/veScaleWeb/misc/join-us.html) ##
## [License](./LICENSE)
-The veScale Project is under the Apache License v2.0.
-
-## Acknowledgement
-
-veScale team would like to sincerely acknowledge the assistance of and collaboration with
-the [PyTorch DTensor team](https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor).
\ No newline at end of file
+The veScale Project is under the Apache License v2.0.
\ No newline at end of file
diff --git a/docs/pictures/ddp.png b/docs/pictures/ddp.png
new file mode 100644
index 0000000..8587f91
Binary files /dev/null and b/docs/pictures/ddp.png differ
diff --git a/docs/pictures/dmodule.png b/docs/pictures/dmodule.png
new file mode 100644
index 0000000..eb5828a
Binary files /dev/null and b/docs/pictures/dmodule.png differ
diff --git a/docs/pictures/doptimizer.png b/docs/pictures/doptimizer.png
new file mode 100644
index 0000000..0239b1a
Binary files /dev/null and b/docs/pictures/doptimizer.png differ
diff --git a/docs/pictures/dtensor.png b/docs/pictures/dtensor.png
new file mode 100644
index 0000000..278a1f5
Binary files /dev/null and b/docs/pictures/dtensor.png differ
diff --git a/docs/pictures/icon.png b/docs/pictures/icon.png
new file mode 100644
index 0000000..5fd3c12
Binary files /dev/null and b/docs/pictures/icon.png differ
diff --git a/docs/pictures/overview.png b/docs/pictures/overview.png
new file mode 100644
index 0000000..6fa7ea8
Binary files /dev/null and b/docs/pictures/overview.png differ
diff --git a/docs/pictures/parallel5d.png b/docs/pictures/parallel5d.png
new file mode 100644
index 0000000..d5f35c3
Binary files /dev/null and b/docs/pictures/parallel5d.png differ
diff --git a/docs/pictures/pytorch.png b/docs/pictures/pytorch.png
new file mode 100644
index 0000000..e637cf3
Binary files /dev/null and b/docs/pictures/pytorch.png differ
diff --git a/docs/pictures/tldr.png b/docs/pictures/tldr.png
new file mode 100644
index 0000000..6b03f60
Binary files /dev/null and b/docs/pictures/tldr.png differ
diff --git a/docs/pictures/vedevicemesh.png b/docs/pictures/vedevicemesh.png
new file mode 100644
index 0000000..ebfb940
Binary files /dev/null and b/docs/pictures/vedevicemesh.png differ
diff --git a/docs/pictures/vescale-logo-dark.png b/docs/pictures/vescale-logo-dark.png
new file mode 100644
index 0000000..037508a
Binary files /dev/null and b/docs/pictures/vescale-logo-dark.png differ
diff --git a/docs/pictures/vescale-logo-light.png b/docs/pictures/vescale-logo-light.png
new file mode 100644
index 0000000..b2df4a1
Binary files /dev/null and b/docs/pictures/vescale-logo-light.png differ
diff --git a/docs/texts/introduction.md b/docs/texts/introduction.md
new file mode 100644
index 0000000..439a6b2
--- /dev/null
+++ b/docs/texts/introduction.md
@@ -0,0 +1,58 @@
+# veScale: A PyTorch Native LLM Training Framework
+
+## TLDR
+
+An _**Industrial-Level**_ Framework for _**Easy-of-Use**_:
+
+
+
+(`*` is under development.)
+
+## Why veScale
+
+The era of giant models today calls forth distributed training.
+Despite countless distributed training frameworks that have been published in the past decade (to name a few), few have excelled at the _**Ease-of-Use**_ and development extensibility demanded by real industry production,
+as the quality most favored for a framework is often the _**Ease-of-Use**_ instead of pure _Performance_.
+Companies developing 100s~1000s models a week benefit the most from a framework that is both easy to use and extend, and provides elegant encapsulation of models and clean APIs.
+
+The _**Ease-of-Use**_ of a framework for training and developing LLM lies in the following essentials:
+
+- 🔥 **PyTorch Native**: _PyTorch_ ecosystem dominates the ML world and owns 92% of models on _HuggingFace_ and 70% of research on _Papers with Code_; Alienating from _PyTorch_ ecosystem makes a framework hard to adapt and extend.
+
+- 🛡 **Zero Model Code Change**: Users' model code should remain untouched, instead of being intertwined with framework code, which requires users to not only manually rewrite the model for distributed training with tons of care, but also painfully debug within the deep coupled model and framework code.
+
+- 🚀 **Single Device Abstraction**: Model developers should focus on developing model architecture itself with single device semantics, rather than being distracted by the complex and error-prone management of multiple devices and diverse interconnects in distributed environments.
+
+- 🎯 **Automatic Parallelism Planning**: Gigantic models cannot be trained without _nD Parallelism_ (_Tensor, Sequence, Data, ZeRO, Pipeline Parallelism, etc._). Users' giant models should be automatically scaled by a framework for _nD_ parallel training, instead of being manually planned and tuned for each operator or layer under different cluster settings, which takes forever.
+
+- ⚡ **Eager & Compile Mode**: Users should enjoy both _Eager_ and _Compile_ mode offered by a framework with:
+ - _Eager_ mode for fast development, convenient debugging, and customization with callbacks and control flows;
+ - _Compile_ mode for ultimate performance boost with a single click.
+
+- đź“€ **Automatic Checkpoint Resharding**: Training models and optimizer states should be saved/loaded automatically and performantly in distributed settings, and can even be _online resharded_ across different cluster sizes and different _nD Parallelism_.
+
+## What is veScale
+
+**veScale**'s overview is as follows:
+
+
+
+We take an initial step to develop an _**Industry-Level**_ framework, **veScale**, that focuses _**Ease-of-Use**_ for scaling LLM training, by combining _PyTorch Nativeness_ and _Automatic Parallelism*_.
+
+Ideally, **veScale** only expects model developers to write a simple model code with native _torch.nn.Module_ under _Zero Code Change_ as if running on a _Single Device_, and then **veScale** will automatically parallelize it across a cluster of devices in a _nD Parallelism_ search space with all the optimizations and heavy lifting handled transparently.
+
+Unlike existing frameworks that rely on _Compile_ mode and a "perfect model graph" for _Automatic Parallelism_, **veScale** is inventing an _Eager-Mode-ONLY*_ _Automatic Parallelism_ that does not rely on the model graph at all.
+Furthermore, **veScale** is also developing a _Mixed Mode_* of partial _Eager_ and partial _Compile_.
+
+**veScale** is designed and implemented on top of a primitive called _DTensor_ that provides a global tensor semantic with local shards distributed on multiple devices.
+**veScale** extends and enhances the _PyTorch DTensor_ for our production standard, and further develops the _Auto-Plan*_ and _Auto-Paralleize_ with a unified configuration and API.
+
+Furthermore, **veScale** also supports online _Auto-Reshard_ for distributed checkpoints, which will be open-sourced as a new project -- **OmniStore**.
+
+(`*` is under development)
+
+## Status of veScale
+
+**veScale** is still in its early phase.
+
+The tentative open-source timeline can be found in the **veScale** [**repo**](https://github.com/volcengine/veScale/tree/main).
\ No newline at end of file
diff --git a/docs/texts/parallel_overview.md b/docs/texts/parallel_overview.md
new file mode 100644
index 0000000..78ce20b
--- /dev/null
+++ b/docs/texts/parallel_overview.md
@@ -0,0 +1,50 @@
+# veScale Parallel Overview
+
+The overview of veScale _n-D parallelism_ is as follows:
+
+
+
+(`*` is under development)
+
+The _Auto-Parallelize_ block takes the untouched _Model_ from the user and _Parallel Plan_ (given by manual effort, prefined for each model type, or automatically generated from _Auto-Plan*_) and then parallelizes the single-device model into _nD Parallelism_ across a mesh of devices.
+
+veScale's _nD Parallelism_ follows a decoupled design where each D of parallelism is handled by an independent sub-block (e.g., _DModule_ only handles _Tensor & Sequence Parallel_, without coupling with other _Parallel_).
+In contrast to the conventional _coupled_ design that intertwines all parallelism together, such a _decoupled_ _nD Parallelism_ enjoys composability, debuggability, explainability, and extensibility, all of which are of great value for hyper-scale training in production.
+
+## 4D Parallelisim API
+
+Our 4D parallelism (_Tensor, Sequence, Data, and ZeRO2_) is as follows:
+
+``` python
+# zero model code change
+from import ,
+
+# create fake model without actual memory usage (optional)
+fake_model = deferred_init(, )
+
+# initialize 4D device mesh
+mesh = init_device_mesh("cuda", (dp_zero_size, tp_sp_size), mesh_dim_names=["DP_ZERO", "TP_SP"])
+
+# parallelize model in tp & sp
+from import sharding_plan
+real_tp_sp_model = parallelize_module(fake_model, mesh["TP_SP"], sharding_plan)
+
+# parallelize model in dp
+ddp_model = DDP(real_tp_sp_model, mesh["DP_ZERO"])
+
+# parallelize model with zero2
+doptimizer = DistributedOptimizer(torch.optim.AdamW, models=[ddp_model])
+
+# train model as if on a single device
+for x in range(dataset):
+ loss = ddp_model(x)
+ loss.backward()
+ doptimizer.step()
+ doptimizer.zero_grad()
+```
+
+More examples can be found in: `/examples/`.
+
+## 5D Parallelisim API
+
+Coming Soon
diff --git a/docs/texts/quick-start.md b/docs/texts/quick-start.md
new file mode 100644
index 0000000..5acfffa
--- /dev/null
+++ b/docs/texts/quick-start.md
@@ -0,0 +1,52 @@
+# Quick Start
+
+First, find the **veScale** [**repo**](https://github.com/volcengine/veScale/tree/main).
+
+## Installation
+
+### From Source
+
+#### Install a Patched Version of PyTorch
+
+```bash
+bash [repo]/patches/build_pytorch_w_patch.sh
+```
+
+This will compile and install a patched version of PyTorch.
+
+#### Install a Patched Version of TorchDistX
+
+```bash
+bash [repo]/patches/build_torchdistX_w_patch.sh
+```
+
+This will compile and install a patched version of TorchdistX (based on its master).
+
+#### Install veScale
+
+```bash
+pushd python && pip3 install -r requirements.txt && pip3 install -e . && popd
+```
+
+This will install **veScale** and its dependencies.
+
+### Docker Image
+
+#### Build the Docker Image
+
+Make sure it is in the veScale directory.
+
+```bash
+docker build .
+```
+It may take a while to build the image.
+
+Once the building process is finished, you can `docker run` with the id.
+
+## Run Examples
+
+- Nano GPT: `/examples/nanogpt_4D_finetune`
+
+- Open LLAMA: `/examples/open_llama_4D_benchmark`
+
+- Mixtral: `/examples/mixtral_4D_benchmark`
\ No newline at end of file
diff --git a/examples/mixtral_4D_benchmark/mixtral_train.py b/examples/mixtral_4D_benchmark/mixtral_train.py
index 91e94eb..d6674e4 100644
--- a/examples/mixtral_4D_benchmark/mixtral_train.py
+++ b/examples/mixtral_4D_benchmark/mixtral_train.py
@@ -40,10 +40,13 @@ def estimate_mixtral(config, bsz, sqence_length):
embed = 4 * bsz * sqence_length * config.hidden_size
# MixtralMoE consists of 3 linear layers.
ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sqence_length
- attn_qkv = 2 * bsz * sqence_length * config.hidden_size * 3 * config.hidden_size
+ # GQA
+ head_size = config.hidden_size // config.num_attention_heads
+ attn_q = 2 * bsz * sqence_length * config.hidden_size * config.hidden_size
+ attn_kv = 2 * 2 * bsz * sqence_length * config.hidden_size * config.num_key_value_heads * head_size
attn_mask = 2 * sqence_length * config.hidden_size
- attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length
- attn = attn_qkv + attn_mask + attn_proj
+ attn_proj = 2 * config.hidden_size * config.hidden_size * bsz * sqence_length
+ attn = attn_q + attn_kv + attn_mask + attn_proj
return embed + (ff + attn) * config.num_hidden_layers
diff --git a/examples/nanogpt_4D_finetune/finetune_4D.py b/examples/nanogpt_4D_finetune/finetune_4D.py
index 291a7a3..9e50bf5 100644
--- a/examples/nanogpt_4D_finetune/finetune_4D.py
+++ b/examples/nanogpt_4D_finetune/finetune_4D.py
@@ -33,7 +33,7 @@
from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group, get_rank
from model import GPTConfig, GPT
-from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale import distribute_tensor
from vescale.dmodule.api import parallelize_module
@@ -114,7 +114,7 @@ def main():
torch.cuda.set_device(device)
init_process_group(backend=backend, world_size=world_size, rank=rank)
- mesh = veDeviceMesh.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
+ VESCALE_DEVICE_MESH.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
ddp_rank = get_rank() // tp_size
else:
rank = 0
@@ -162,8 +162,8 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
else:
x, y = x.to(device), y.to(device)
if ddp:
- x = distribute_tensor(x, mesh["TP"], [Replicate()])
- y = distribute_tensor(y, mesh["TP"], [Replicate()])
+ x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Replicate()])
+ y = distribute_tensor(y, VESCALE_DEVICE_MESH["TP"], [Replicate()])
return x, y
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -235,10 +235,10 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
# + + + parallelize the model and wrap it with DDP using veScale APIs
if ddp:
- model = parallelize_module(model, mesh["TP"], nanoGPT_plan)
+ model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], nanoGPT_plan)
model = DDP(
model,
- data_pg_or_device_mesh=mesh["DP"],
+ data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=DDP_grads_in_fp32,
overlap_grad_reduce=False,
use_distributed_optimizer=use_DO,
diff --git a/test/checkpoint/common_func.py b/test/checkpoint/common_func.py
index d3b5525..d4e8ca8 100644
--- a/test/checkpoint/common_func.py
+++ b/test/checkpoint/common_func.py
@@ -122,22 +122,22 @@ def build_gpt_model_optimizer_and_dataset(init_method, dp_size=1, tp_size=1):
open_source = False
try:
- from vescale.devicemesh_api import veDeviceMesh
+ from vescale.devicemesh_api import VESCALE_DEVICE_MESH
except ImportError:
open_source = True
- device_mesh = veDeviceMesh.init_device_mesh(
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(dp_size, tp_size),
mesh_dim_names=("DP", "TP"),
)
# Enable tensor Parallel
- tp_gpt = parallelize_module(gpt, device_mesh["TP"], nanoGPT_plan)
+ tp_gpt = parallelize_module(gpt, VESCALE_DEVICE_MESH["TP"], nanoGPT_plan)
# Enable data Parallel
ddp_gpt = DDP(
tp_gpt,
- data_pg_or_device_mesh=device_mesh["DP"],
+ data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
@@ -280,24 +280,22 @@ def get_open_llama_model(layer_number=None):
def get_open_llama_model_optimizer(dp_size, tp_size, layer_number=None):
- from vescale.devicemesh_api import veDeviceMesh
+ from vescale.devicemesh_api import VESCALE_DEVICE_MESH
- device_mesh = veDeviceMesh.init_device_mesh(
- "cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"), check_uniqueness=True
- )
+ VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"), check_uniqueness=True)
# Set 4 layers to avoid timeout on CI
# Use 32 layers when running on training platform
vescale_decoder, config = get_open_llama_model(layer_number=layer_number)
vescale_decoder = parallelize_module(
vescale_decoder,
- device_mesh["TP"],
+ VESCALE_DEVICE_MESH["TP"],
sharding_plan,
)
ddp_decoder = DDP(
vescale_decoder,
- data_pg_or_device_mesh=device_mesh["DP"],
+ data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
diff --git a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
index a6d7433..45f2e81 100644
--- a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
+++ b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
@@ -18,7 +18,7 @@
import torch.distributed as dist
from common_dtensor import DTensorTestBase, with_comms, skip_unless_torch_gpu
from torch.testing._internal.common_utils import run_tests
-from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
import vescale
from vescale.dtensor.placement_types import Replicate
@@ -48,17 +48,16 @@ def test_save(self):
)
# turn off 'check_uniqueness' to allow multiple updates of global device mesh during runtime
- device_mesh = veDeviceMesh.init_device_mesh(
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(1, 2, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
- tp_sub_mesh = device_mesh["TP"]
# Do fwd+bwd+step on the first data
for X, Y in data_set[:1]:
- input = vescale.distribute_tensor(X, device_mesh["TP"], [Replicate()])
- output = vescale.distribute_tensor(Y, device_mesh["TP"], [Replicate()])
+ input = vescale.distribute_tensor(X, VESCALE_DEVICE_MESH["TP"], [Replicate()])
+ output = vescale.distribute_tensor(Y, VESCALE_DEVICE_MESH["TP"], [Replicate()])
dist_optimizer.zero_grad()
_, output = ddp_gpt(input, output)
loss = output.mean()
diff --git a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py
index 5440d98..d37fb33 100644
--- a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py
+++ b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py
@@ -18,7 +18,7 @@
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests
-from vescale.devicemesh_api import veDeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from common_dtensor import DTensorTestBase, with_comms
import vescale
@@ -54,11 +54,13 @@ def test_open_llama2_with_ddp(self):
ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer}
vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state)
# For processes with dp_rank = 0, dump model state_dict
- if veDeviceMesh.get_data_parallel_rank() == 0:
+ if VESCALE_DEVICE_MESH.get_data_parallel_rank() == 0:
dumped_model_sd = {}
for k, v in ddp_decoder.state_dict().items():
dumped_model_sd[k] = v._local_tensor
- torch.save(dumped_model_sd, f"open_llama_dp_reshard_model_tp_{veDeviceMesh.get_tensor_parallel_rank()}.pt")
+ torch.save(
+ dumped_model_sd, f"open_llama_dp_reshard_model_tp_{VESCALE_DEVICE_MESH.get_tensor_parallel_rank()}.pt"
+ )
# Save merged optimizer state dict
optimizer_state = ve_optimizer.state_dict()
@@ -86,7 +88,7 @@ def test_open_llama2_with_ddp(self):
vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state)
# Load model state dict and verify it
dumped_model_sd = torch.load(
- f"open_llama_dp_reshard_model_tp_{veDeviceMesh.get_tensor_parallel_rank()}.pt", map_location="cpu"
+ f"open_llama_dp_reshard_model_tp_{VESCALE_DEVICE_MESH.get_tensor_parallel_rank()}.pt", map_location="cpu"
)
current_model_sd = ddp_decoder.state_dict()
diff --git a/test/dmodule/test_dfactory.py b/test/dmodule/test_dfactory.py
index a4808f6..74ca372 100644
--- a/test/dmodule/test_dfactory.py
+++ b/test/dmodule/test_dfactory.py
@@ -134,7 +134,7 @@ def _match_factory_dfactory(self, factory, dfactory, global_shape, placements, d
for actual, golden in zip(actuals, goldens):
self.assertTrue(isinstance(actual, DTensor))
self.assertTrue(isinstance(golden, DTensor))
- if factory in [torch.empty]: # TODO: fix torch.rand to equal
+ if factory in [torch.empty]:
is_match = dtensor._utils._equal_meta_data(actual, golden, exact_device=True)
else:
is_match = dtensor.equal(actual, golden)
diff --git a/test/dtensor/device_mesh/test_initialize.py b/test/dtensor/device_mesh/test_initialize.py
new file mode 100644
index 0000000..ed9caaf
--- /dev/null
+++ b/test/dtensor/device_mesh/test_initialize.py
@@ -0,0 +1,83 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+import warnings
+import torch
+from torch.distributed import init_process_group, new_group
+from common_dtensor import DTensorTestBase
+from torch.testing._internal.common_utils import run_tests
+from vescale.dtensor.device_mesh import DeviceMesh
+
+
+class DeviceMeshInitializeTest(DTensorTestBase):
+ @property
+ def world_size(self) -> int:
+ return 4
+
+ def _manual_setup(self):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12345"
+ self.device_type = "cuda"
+ init_process_group(
+ backend="nccl",
+ rank=self.rank,
+ world_size=self.world_size,
+ ) # In DTensorTestBase, but do not use @with_comm
+
+ def test_init_process_group(self):
+ """
+ Test DeviceMesh's initialization reaction to map rank to cuda device
+ when users fail to do so. We simulate the situation by setting up distributed
+ environment partially. DeviceMesh initialization takes as input a process group.
+ """
+ self._manual_setup()
+ input_pg = new_group(ranks=list(range(self.world_size)))
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ device_mesh = DeviceMesh("cuda", torch.arange(self.world_size), pg=input_pg, _validate_mesh=False)
+ if self.rank != 0:
+ # validate on rank > 0 since torch.cuda.current_device() returns 0 when the device hasn't been set.
+ all_warnings = [str(item.message) for item in w]
+ self.assertEqual(len(all_warnings), 2)
+ self.assertTrue(
+ any(
+ "Construction from given ProcessGroup is only supported for 1D mesh currently." in warn
+ for warn in all_warnings
+ )
+ )
+ self.assertTrue(any(warn.startswith("Remember to set cuda device id") for warn in all_warnings))
+
+ def test_init_no_process_group(self):
+ """
+ Test DeviceMesh's initialization reaction to map rank to cuda device
+ when users fail to do so. We simulate the situation by setting up distributed
+ environment partially. DeviceMesh initialization takes no process group.
+ """
+ self._manual_setup()
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always") # Catch all warnings
+ device_mesh = DeviceMesh("cuda", torch.arange(self.world_size), pg=None, _validate_mesh=False)
+ if self.rank != 0:
+ # validate on rank > 0 since torch.cuda.current_device() returns 0 when the device hasn't been set.
+ all_warnings = [str(item.message) for item in w]
+ self.assertEqual(len(all_warnings), 1)
+ self.assertTrue(any(warn.startswith("Remember to set cuda device id") for warn in all_warnings))
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/dtensor/ops/test_flash_attn.py b/test/dtensor/ops/test_flash_attn.py
index 81d583b..c33faae 100644
--- a/test/dtensor/ops/test_flash_attn.py
+++ b/test/dtensor/ops/test_flash_attn.py
@@ -21,13 +21,14 @@
from torch.testing._internal.common_utils import (
run_tests,
)
-from vescale.dtensor.placement_types import Shard, Replicate, Partial
+from vescale.dtensor.placement_types import Shard
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.api import distribute_tensor
HIDDEN_DIM = 4
BSZ = 3
+
class RepeatTest(DTensorTestBase):
@property
def world_size(self):
diff --git a/test/dtensor/ops/test_random_ops.py b/test/dtensor/ops/test_random_ops.py
index 8081f2a..1210e49 100644
--- a/test/dtensor/ops/test_random_ops.py
+++ b/test/dtensor/ops/test_random_ops.py
@@ -70,7 +70,7 @@ def _run_init_op(self, init_op, *args, **kwargs):
dempty(*global_shape, device_mesh=device_mesh, placements=placements), *args, **kwargs
)
self.assertTrue(list(dtensor._spec.placements) == placements)
- self.assertEqual(dtensor.to_local(), dist_expected.to_local(), atol=0.0, rtol=0.0)
+ self.assertEqual(dtensor._local_tensor, dist_expected._local_tensor, atol=0.0, rtol=0.0)
full_tensor = dtensor.full_tensor()
self.assertEqual(full_tensor, expected_tensor, atol=0.0, rtol=0.0)
diff --git a/test/dtensor/ops/test_tensor_ops.py b/test/dtensor/ops/test_tensor_ops.py
index 523d28a..f1c24be 100644
--- a/test/dtensor/ops/test_tensor_ops.py
+++ b/test/dtensor/ops/test_tensor_ops.py
@@ -459,9 +459,12 @@ def test_expand_with_broadcast(self):
tensor = torch.randn((4,))
matrix = torch.randn((2, 3, 4))
dtensor = distribute_tensor(tensor, device_mesh, [Replicate()])
- dmatrix = distribute_tensor(matrix, device_mesh, [Shard(0)])
- dout = dtensor.expand_as(dmatrix)
- assert dout._spec.placements[0] == Shard(0), f"sharding error {dout._spec}"
+ dmatrix_shard = distribute_tensor(matrix, device_mesh, [Shard(0)])
+ dout1 = dtensor.expand_as(dmatrix_shard)
+ assert dout1._spec.placements[0] == Shard(0), f"sharding error {dout1._spec}"
+ dmatrix_partial = dmatrix_shard.redistribute(device_mesh, [Partial()])
+ dout2 = dtensor.expand_as(dmatrix_partial)
+ assert dout2._spec.placements[0] == Replicate(), f"sharding error {dout2._spec}"
@with_comms
@skip("failed")
diff --git a/test/initialize/test_defer_init.py b/test/initialize/test_defer_init.py
index 9f2a51d..b6f0df7 100644
--- a/test/initialize/test_defer_init.py
+++ b/test/initialize/test_defer_init.py
@@ -15,6 +15,7 @@
#
################################################################################
+import unittest
from common_dtensor import skip_unless_torch_gpu, with_comms, DTensorTestBase
from torch.testing._internal.common_utils import run_tests
@@ -27,7 +28,6 @@
from vescale.dtensor.placement_types import Replicate, Shard
from vescale.dtensor.dtensor import DTensor
from vescale.dtensor.device_mesh import DeviceMesh
-from vescale.dtensor import randn
from vescale.initialize.deferred_init import deferred_init, is_deferred, materialize_dtensor, materialize_dparameter
from vescale.dmodule.api import parallelize_module
from vescale.dtensor.random import manual_seed
@@ -91,43 +91,6 @@ def _assert_eq_empty(self, x: torch.Tensor, y: torch.Tensor):
self.assertTrue(x.layout == y.layout)
self.assertTrue(x.requires_grad == y.requires_grad)
- @skip_unless_torch_gpu
- @with_comms
- def test_accuracy_random2(self):
- mesh = DeviceMesh("cuda", list(range(self.world_size)))
-
- torch.use_deterministic_algorithms(True)
-
- # replicate
- torch.manual_seed(0)
- torch.cuda.manual_seed(0)
- defer_dtensor = deferred_init(torch.randn, (4, 16, 16))
- dtensor_replicate = materialize_dtensor(defer_dtensor, mesh, [Replicate()])
-
- torch.manual_seed(0)
- torch.cuda.manual_seed(0)
- expected_tensor = torch.randn((4, 16, 16), device=mesh.device_type)
-
- self.assertTrue(torch.equal(dtensor_replicate._local_tensor, expected_tensor))
-
- # shard
- torch.manual_seed(0)
- torch.cuda.manual_seed(0)
- defer_dtensor = deferred_init(torch.randn, (4, 16, 16))
- dtensor = materialize_dtensor(defer_dtensor, mesh, [Shard(1)])
-
- torch.manual_seed(0)
- torch.cuda.manual_seed(0)
- dtensor_rand = randn((4, 16, 16), device_mesh=mesh, placements=[Shard(1)])
-
- torch.manual_seed(0)
- torch.cuda.manual_seed(0)
- expected_tensor = torch.randn((4, 16 // self.world_size, 16), device=mesh.device_type)
-
- self._assert_eq_empty(dtensor._local_tensor, expected_tensor)
- self._assert_eq_empty(dtensor_rand._local_tensor, expected_tensor)
- self._assert_eq_empty(dtensor_rand._local_tensor, dtensor._local_tensor)
-
@skip_unless_torch_gpu
@with_comms
def test_local_shape(self):
@@ -215,38 +178,44 @@ def world_size(self):
@skip_unless_torch_gpu
@with_comms
+ @unittest.skip(
+ "torchdistx.deferred_init._C.is_gen_by_random_op doesn't know that nn.Linear is randomly initialized"
+ )
def test_dparameter(self):
mesh = DeviceMesh("cuda", list(range(self.world_size)))
-
- fc = deferred_init(nn.Linear, 4, 4)
- self.assertTrue(is_deferred(fc))
- self.assertTrue(is_deferred(fc.weight))
- self.assertTrue(is_fake(fc.weight))
- self.assertTrue(not is_deferred(fc.weight.data)) # NOTE
- self.assertTrue(is_fake(fc.weight.data))
- self.assertTrue(fc.weight.requires_grad)
- self.assertTrue(not fc.weight.data.requires_grad)
-
- if self.rank == 0:
- print("*** BEFORE ***", fc.weight)
-
- dparam = materialize_dparameter(fc.weight, mesh, [Shard(0)])
-
- if self.rank == 0:
- print("*** AFTER ***", dparam)
-
- # self.assertTrue(not is_deferred(dparam))
- # self.assertTrue(not is_fake(dparam))
- # self.assertTrue(not is_deferred(dparam.data))
- # self.assertTrue(not is_fake(dparam.data))
- self.assertTrue(isinstance(dparam, nn.Parameter))
- self.assertTrue(dparam.requires_grad)
- self.assertTrue(isinstance(dparam.data, DTensor))
- self.assertTrue(not dparam.data.requires_grad)
- self.assertTrue(isinstance(dparam.data._local_tensor, torch.Tensor))
- self.assertEqual(dparam.data._local_tensor.shape, (1, 4))
- self.assertTrue(dparam.data._local_tensor.is_cuda)
- self.assertTrue(not dparam.data._local_tensor.requires_grad)
+ # all_shapes = [(4, 4), (5, 9), (13, 7)]
+ # all_placesments = [[Shard(0)], [Shard(1)], [Replicate()]]
+ all_shapes = [(4, 4)]
+ all_placesments = [[Shard(0)]]
+ for shape in all_shapes:
+ for placements in all_placesments:
+ torch.cuda.manual_seed_all(0)
+ expected_fc = nn.Linear(*shape, device=self.device_type)
+ dist_fc_wgt = distribute_tensor(expected_fc.weight, mesh, placements)
+ if mesh.get_rank() == 0:
+ print(f"expected_fc.weight {expected_fc.weight}")
+
+ manual_seed(0, mesh)
+ fc = deferred_init(nn.Linear, *shape)
+ self.assertTrue(is_deferred(fc))
+ self.assertTrue(is_deferred(fc.weight))
+ self.assertTrue(is_fake(fc.weight))
+ self.assertTrue(not is_deferred(fc.weight.data)) # NOTE
+ self.assertTrue(is_fake(fc.weight.data))
+ self.assertTrue(fc.weight.requires_grad)
+ self.assertTrue(not fc.weight.data.requires_grad)
+
+ dparam = materialize_dparameter(fc.weight, mesh, placements)
+ print(f"rank {mesh.get_rank()} dparam.data {dparam.data._local_tensor}")
+ self.assertTrue(isinstance(dparam, nn.Parameter))
+ self.assertTrue(dparam.requires_grad)
+ self.assertTrue(isinstance(dparam.data, DTensor))
+ self.assertTrue(not dparam.data.requires_grad)
+ self.assertTrue(isinstance(dparam.data._local_tensor, torch.Tensor))
+
+ self.assertEqual(dparam.data._local_tensor, dist_fc_wgt._local_tensor, atol=0.0, rtol=0.0)
+ full_dparam = dparam.data.full_tensor()
+ self.assertEqual(full_dparam, expected_fc.weight, atol=0.0, rtol=0.0)
class MLP(nn.Module):
diff --git a/test/parallel/devicemesh_api/_build.py b/test/parallel/devicemesh_api/_build.py
index c269638..3771f84 100644
--- a/test/parallel/devicemesh_api/_build.py
+++ b/test/parallel/devicemesh_api/_build.py
@@ -20,7 +20,7 @@
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.dmodule.api import parallelize_module
-from vescale.devicemesh_api import veDeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
def system_setup():
@@ -65,26 +65,32 @@ def prepare_config_and_data():
return model_args, data_set
-def build_gpt_model_and_optimizer(gptconf, init_method, dp_size, tp_size, sharding_plan, use_dist_optimizer=False):
+def build_gpt_model_and_optimizer(
+ gptconf, init_method, dp_size, tp_size, sharding_plan, use_dist_optimizer=False, device_type="cuda"
+):
if init_method == "scratch":
model = GPT(gptconf).bfloat16()
else:
model = GPT.from_pretrained(init_method, dict(dropout=0.0)).bfloat16()
- device_mesh = veDeviceMesh.init_device_mesh(
- "cuda",
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type,
mesh_shape=(dp_size, tp_size),
mesh_dim_names=("DP", "TP"),
)
if tp_size > 1:
# Enable tensor parallelism
- model = parallelize_module(model, device_mesh["TP"], sharding_plan)
- else:
+ model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], sharding_plan)
+ elif device_type == "cuda":
model.to("cuda")
if dp_size > 1:
# Enable data Parallel
- dp_comm = veDeviceMesh["DP"] if veDeviceMesh.ndim > 1 else veDeviceMesh.get_data_parallel_dim_groups()
+ dp_comm = (
+ VESCALE_DEVICE_MESH["DP"]
+ if VESCALE_DEVICE_MESH.ndim > 1
+ else VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
+ )
model = DDP(
model,
data_pg_or_device_mesh=dp_comm,
@@ -98,7 +104,11 @@ def build_gpt_model_and_optimizer(gptconf, init_method, dp_size, tp_size, shardi
# Build distributed optimizer
if use_dist_optimizer and tp_size > 1:
- dp_comm = veDeviceMesh["DP"] if veDeviceMesh.ndim > 1 else veDeviceMesh.get_data_parallel_dim_groups()
+ dp_comm = (
+ VESCALE_DEVICE_MESH["DP"]
+ if VESCALE_DEVICE_MESH.ndim > 1
+ else VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
+ )
optimizer = DistributedOptimizer(
optimizer,
clip_grad=0.0,
@@ -112,4 +122,17 @@ def build_gpt_model_and_optimizer(gptconf, init_method, dp_size, tp_size, shardi
models=[model],
)
- return model, optimizer, device_mesh
+ return model, optimizer, VESCALE_DEVICE_MESH.get()
+
+
+def prepare_data(bsz, hidden_dim, dtype=torch.float, device_type="cuda"):
+ x1, y1 = torch.rand(bsz, hidden_dim, dtype=dtype), torch.rand(bsz, hidden_dim, dtype=dtype)
+ x2, y2 = torch.rand(bsz, hidden_dim, dtype=dtype), torch.rand(bsz, hidden_dim, dtype=dtype)
+ x3, y3 = torch.rand(bsz, hidden_dim, dtype=dtype), torch.rand(bsz, hidden_dim, dtype=dtype)
+ x4, y4 = torch.rand(bsz, hidden_dim, dtype=dtype), torch.rand(bsz, hidden_dim, dtype=dtype)
+ if device_type == "cuda":
+ x1, y1 = x1.cuda(), y1.cuda()
+ x2, y2 = x2.cuda(), y2.cuda()
+ x3, y3 = x3.cuda(), y3.cuda()
+ x4, y4 = x4.cuda(), y4.cuda()
+ return [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
diff --git a/test/parallel/devicemesh_api/test_api.py b/test/parallel/devicemesh_api/test_api.py
index abb023e..f1b0280 100644
--- a/test/parallel/devicemesh_api/test_api.py
+++ b/test/parallel/devicemesh_api/test_api.py
@@ -18,7 +18,7 @@
from torch.testing._internal.common_utils import run_tests
from torch.distributed import get_rank
from torch.distributed.distributed_c10d import get_process_group_ranks
-from vescale.devicemesh_api import veDeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.dtensor.device_mesh import DeviceMesh
from common_dtensor import DTensorTestBase, with_comms
@@ -33,47 +33,47 @@ def test_initialize(self):
"""
Test utilities to initialize global DeviceMesh.
"""
- # the initialized global device mesh is an outcome of initializing veDeviceMesh API
- global_device_mesh = veDeviceMesh.init_device_mesh(
+ # the initialized global device mesh is an outcome of initializing VESCALE_DEVICE_MESH API
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2),
mesh_dim_names=("DP", "TP"),
)
device_mesh = DeviceMesh(self.device_type, torch.tensor([[0, 1], [2, 3]]))
- self.assertEqual(global_device_mesh.mesh, device_mesh.mesh)
- self.assertEqual(global_device_mesh, veDeviceMesh.get())
+ self.assertEqual(VESCALE_DEVICE_MESH.get().mesh, device_mesh.mesh)
+ self.assertEqual(VESCALE_DEVICE_MESH.get(), VESCALE_DEVICE_MESH.get())
initial_config = {
"device_type": "cuda",
"mesh_shape": (2, 2),
"mesh_dim_names": ("dp", "tp"),
}
- # Taking as input parameters of veDeviceMesh.init_device_mesh, get() can initialize global DeviceMesh
- second_global_device_mesh = veDeviceMesh.get(**initial_config)
- self.assertEqual(veDeviceMesh.get().mesh, second_global_device_mesh.mesh)
+ # Taking as input parameters of VESCALE_DEVICE_MESH.init_device_mesh, get() can initialize global DeviceMesh
+ second_global_device_mesh = VESCALE_DEVICE_MESH.get(**initial_config)
+ self.assertEqual(VESCALE_DEVICE_MESH.get().mesh, second_global_device_mesh.mesh)
@with_comms
def test_basic_properties(self):
"""
Test utilities to perform basic properties inherited from upstream DeviceMesh.
"""
- # veDeviceMesh returns the global device mesh upon which is is initialized
- _ = veDeviceMesh.init_device_mesh(
+ # VESCALE_DEVICE_MESH returns the global device mesh upon which is is initialized
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2),
mesh_dim_names=("DP", "TP"),
)
- self.assertEqual(veDeviceMesh.shape, tuple([2, 2]))
- self.assertEqual(veDeviceMesh.ndim, 2)
- self.assertEqual(veDeviceMesh.size(), 4)
- self.assertEqual(veDeviceMesh.size(0), 2)
- self.assertEqual(veDeviceMesh.size(1), 2)
- self.assertFalse("PP" in veDeviceMesh._MESH_DIM_NAMES_LOOKUP)
- dp_mesh = veDeviceMesh["DP"]
+ self.assertEqual(VESCALE_DEVICE_MESH.shape, [2, 2])
+ self.assertEqual(VESCALE_DEVICE_MESH.ndim, 2)
+ self.assertEqual(VESCALE_DEVICE_MESH.size(), 4)
+ self.assertEqual(VESCALE_DEVICE_MESH.size(0), 2)
+ self.assertEqual(VESCALE_DEVICE_MESH.size(1), 2)
+ self.assertFalse("PP" in VESCALE_DEVICE_MESH._MESH_DIM_NAMES_LOOKUP)
+ dp_mesh = VESCALE_DEVICE_MESH["DP"]
dp_submesh_mesh = dp_mesh.mesh.tolist()
- tp_mesh = veDeviceMesh["TP"]
+ tp_mesh = VESCALE_DEVICE_MESH["TP"]
tp_submesh_mesh = tp_mesh.mesh.tolist()
# upstream DeviceMesh's get_coordinate utility
- strategy_coordinate = veDeviceMesh.get_coordinate()
+ strategy_coordinate = VESCALE_DEVICE_MESH.get_coordinate()
if get_rank() == 0:
self.assertEqual(dp_submesh_mesh, [0, 2])
self.assertEqual(tp_submesh_mesh, [0, 1])
@@ -88,19 +88,19 @@ def test_basic_utils(self):
"""
Test utilities to perform basic utilities with regards to local ranks and strategies.
"""
- # veDeviceMesh returns the global device mesh upon which is is initialized
- _ = veDeviceMesh.init_device_mesh(
+ # VESCALE_DEVICE_MESH returns the global device mesh upon which is is initialized
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2),
mesh_dim_names=("DP", "TP"),
)
- self.assertEqual(veDeviceMesh.get_local_rank(), get_rank())
- self.assertEqual(veDeviceMesh.get_strategy_size(0), veDeviceMesh.get_strategy_size("DP"))
- self.assertEqual(veDeviceMesh.get_strategy_size("TP"), 2)
- self.assertEqual(veDeviceMesh.lookup_rank("TP"), veDeviceMesh.get_strategy_coordinate()[1])
- self.assertEqual(veDeviceMesh.lookup_rank("DP"), veDeviceMesh.get_strategy_coordinate()[0])
- self.assertEqual(veDeviceMesh.get_strategy_coordinate(local_rank=0), [0, 0])
- self.assertEqual(veDeviceMesh.get_strategy_coordinate(local_rank=3), [1, 1])
+ self.assertEqual(VESCALE_DEVICE_MESH.get_local_rank(), get_rank())
+ self.assertEqual(VESCALE_DEVICE_MESH.get_strategy_size(0), VESCALE_DEVICE_MESH.get_strategy_size("DP"))
+ self.assertEqual(VESCALE_DEVICE_MESH.get_strategy_size("TP"), 2)
+ self.assertEqual(VESCALE_DEVICE_MESH.lookup_rank("TP"), VESCALE_DEVICE_MESH.get_strategy_coordinate()[1])
+ self.assertEqual(VESCALE_DEVICE_MESH.lookup_rank("DP"), VESCALE_DEVICE_MESH.get_strategy_coordinate()[0])
+ self.assertEqual(VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank=0), [0, 0])
+ self.assertEqual(VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank=3), [1, 1])
class TestStrategyUtil(DTensorTestBase):
@@ -113,26 +113,26 @@ def test_strategy_rank(self):
"""
Test utilities to get id of a global rank along dimensions.
"""
- # the initialized global device mesh is an outcome of initializing veDeviceMesh API
- device_mesh_one = veDeviceMesh.init_device_mesh(
+ # the initialized global device mesh is an outcome of initializing VESCALE_DEVICE_MESH API
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
- pp_rank = veDeviceMesh.get_pipeline_parallel_rank()
- dp_rank = veDeviceMesh.get_data_parallel_rank()
- tp_rank = veDeviceMesh.get_tensor_parallel_rank()
+ pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank()
+ dp_rank = VESCALE_DEVICE_MESH.get_data_parallel_rank()
+ tp_rank = VESCALE_DEVICE_MESH.get_tensor_parallel_rank()
if get_rank() == 7:
self.assertEqual((pp_rank, dp_rank, tp_rank), (1, 1, 1))
# now update a new global device mesh
- device_mesh_two = veDeviceMesh.init_device_mesh(
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(4, 1, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
- pp_rank_two = veDeviceMesh.get_pipeline_parallel_rank()
- dp_rank_two = veDeviceMesh.get_data_parallel_rank()
- tp_rank_two = veDeviceMesh.get_tensor_parallel_rank()
+ pp_rank_two = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank()
+ dp_rank_two = VESCALE_DEVICE_MESH.get_data_parallel_rank()
+ tp_rank_two = VESCALE_DEVICE_MESH.get_tensor_parallel_rank()
if get_rank() == 0:
self.assertEqual((pp_rank_two, dp_rank_two, tp_rank_two), (0, 0, 0))
if get_rank() == 7:
@@ -141,20 +141,20 @@ def test_strategy_rank(self):
@with_comms
def test_strategy_mesh(self):
"""
- Test veDeviceMesh utilities to generate sub-DeviceMesh along a parallel dimension.
+ Test VESCALE_DEVICE_MESH utilities to generate sub-DeviceMesh along a parallel dimension.
"""
- # veDeviceMesh returns the global device mesh upon which is is initialized
- _ = veDeviceMesh.init_device_mesh(
+ # VESCALE_DEVICE_MESH returns the global device mesh upon which is is initialized
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
# sub-DeviceMesh for TP view
- tp_mesh = veDeviceMesh.get_tensor_parallel_mesh()
+ tp_mesh = VESCALE_DEVICE_MESH.get_tensor_parallel_mesh()
# sub-DeviceMesh for DP view
- dp_mesh = veDeviceMesh.get_data_parallel_mesh()
+ dp_mesh = VESCALE_DEVICE_MESH.get_data_parallel_mesh()
# sub-DeviceMesh for PP view (2 stages)
- pp_mesh = veDeviceMesh.get_pipeline_parallel_mesh()
+ pp_mesh = VESCALE_DEVICE_MESH.get_pipeline_parallel_mesh()
if get_rank() == 6:
self.assertEqual(tp_mesh.mesh.tolist(), [6, 7])
self.assertEqual(dp_mesh.mesh.tolist(), [4, 6])
@@ -163,17 +163,17 @@ def test_strategy_mesh(self):
@with_comms
def test_process_groups(self):
"""
- Test veDeviceMesh utilities to query process groups in Omnistore
+ Test VESCALE_DEVICE_MESH utilities to query process groups in Omnistore
and distributed data parallel APIs.
"""
- # the initialized global device mesh is an outcome of initializing veDeviceMesh API
- device_mesh_one = veDeviceMesh.init_device_mesh(
+ # the initialized global device mesh is an outcome of initializing VESCALE_DEVICE_MESH API
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(2, 1, 4),
mesh_dim_names=("PP", "DP", "TP"),
)
- tp_process_group = veDeviceMesh.get_tensor_parallel_dim_groups()
- dp_process_group = veDeviceMesh.get_data_parallel_dim_groups()
+ tp_process_group = VESCALE_DEVICE_MESH.get_tensor_parallel_dim_groups()
+ dp_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
tp_member_ranks = get_process_group_ranks(tp_process_group)
dp_member_ranks = get_process_group_ranks(dp_process_group)
if get_rank() == 4:
@@ -183,13 +183,13 @@ def test_process_groups(self):
self.assertEqual(tp_member_ranks, [1, 5])
self.assertEqual(dp_member_ranks, [5])
# now update a new global device mesh
- device_mesh_two = veDeviceMesh.init_device_mesh(
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(4, 2),
mesh_dim_names=("DP", "TP"),
)
- tp_process_group = veDeviceMesh.get_tensor_parallel_dim_groups()
- dp_process_group = veDeviceMesh.get_data_parallel_dim_groups()
+ tp_process_group = VESCALE_DEVICE_MESH.get_tensor_parallel_dim_groups()
+ dp_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
tp_member_ranks = get_process_group_ranks(tp_process_group)
dp_member_ranks = get_process_group_ranks(dp_process_group)
if get_rank() == 4:
@@ -202,38 +202,38 @@ def test_process_groups(self):
@with_comms
def test_global_meshes(self):
"""
- Test veDeviceMesh utilities to retrieve a list of tensor parallel,
+ Test VESCALE_DEVICE_MESH utilities to retrieve a list of tensor parallel,
and pipeline parallel submeshes.
"""
- # veDeviceMesh returns the global device mesh upon which is is initialized
- device_mesh = veDeviceMesh.init_device_mesh(
+ # VESCALE_DEVICE_MESH returns the global device mesh upon which is is initialized
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(4, 1, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
- tensor_parallel_meshes = veDeviceMesh.get_global_tensor_parallel_meshes()
+ tensor_parallel_meshes = VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes()
tensor_meshes = [item.mesh.tolist() for item in tensor_parallel_meshes]
self.assertEqual(tensor_meshes, [[0, 1], [2, 3], [4, 5], [6, 7]])
- pipeline_parallel_meshes = veDeviceMesh.get_global_pipeline_parallel_meshes()
+ pipeline_parallel_meshes = VESCALE_DEVICE_MESH.get_global_pipeline_parallel_meshes()
pipeline_meshes = [item.mesh.tolist() for item in pipeline_parallel_meshes]
self.assertEqual(pipeline_meshes, [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]])
@with_comms
def test_stage_query(self):
"""
- Test veDeviceMesh utilities to query whether current pipeline stage
+ Test VESCALE_DEVICE_MESH utilities to query whether current pipeline stage
is the first and last stage.
"""
- # veDeviceMesh returns the global device mesh upon which is is initialized
- device_mesh = veDeviceMesh.init_device_mesh(
+ # VESCALE_DEVICE_MESH returns the global device mesh upon which is is initialized
+ VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(4, 1, 2),
mesh_dim_names=("PP", "DP", "TP"),
)
- self.assertEqual(veDeviceMesh.is_first_stage(), veDeviceMesh.get_pipeline_parallel_rank() == 0)
+ self.assertEqual(VESCALE_DEVICE_MESH.is_first_stage(), VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0)
self.assertEqual(
- veDeviceMesh.is_last_stage(),
- veDeviceMesh.get_pipeline_parallel_rank() == veDeviceMesh.get_strategy_size("PP") - 1,
+ VESCALE_DEVICE_MESH.is_last_stage(),
+ VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == VESCALE_DEVICE_MESH.get_strategy_size("PP") - 1,
)
diff --git a/test/parallel/devicemesh_api/test_nano_gpt.py b/test/parallel/devicemesh_api/test_nano_gpt.py
index df9f1fc..7ae51ce 100644
--- a/test/parallel/devicemesh_api/test_nano_gpt.py
+++ b/test/parallel/devicemesh_api/test_nano_gpt.py
@@ -17,7 +17,8 @@
from torch.testing._internal.common_utils import run_tests
import torch
import vescale
-from vescale.devicemesh_api import veDeviceMesh
+import unittest
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.dtensor.placement_types import Replicate
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from parallel.devicemesh_api._build import build_gpt_model_and_optimizer, prepare_config_and_data, system_setup
@@ -39,12 +40,15 @@ def init_method(self):
# the GPT loads pretrained weights from OpenAI GPT2 repository on Huggingface
return "scratch"
+ @unittest.skip("fix me: CPU dist optimizer")
@with_comms_device("cpu")
def test_2d_dp_tp_doptim_gpt_cpu(self):
"""
Test 3-dimensional strategy demo on CPU.
When the demo runs on CPU, it uses gloo as backend.
"""
+ device_handle = getattr(torch, self.device_type, None)
+ device_handle.set_device(0)
self._test_2d_dp_tp_doptim_gpt()
@with_comms_device("cuda")
@@ -55,6 +59,7 @@ def test_2d_dp_tp_doptim_gpt_cuda(self):
"""
self._test_2d_dp_tp_doptim_gpt()
+ @unittest.skip("fix me: CPU dist optimizer")
@with_comms_device("cpu")
def test_2d_dp_tp_sp_doptim_gpt_cpu(self):
"""
@@ -103,6 +108,7 @@ def _test_2d_dp_tp_sp_doptim_gpt(self):
}
self._test_gpt(task_config)
+ @unittest.skip("fix me: CPU dist optimizer")
@with_comms_device("cpu")
def test_2d_dp_tp_base_optimizer_gpt_cpu(self):
"""
@@ -150,8 +156,8 @@ def _test_gpt(self, task_config):
optimizer.step()
def _process_data(self, x, y):
- if veDeviceMesh.get_strategy_size("TP") > 1:
- tp_mesh = veDeviceMesh.get_tensor_parallel_mesh()
+ if VESCALE_DEVICE_MESH.get_strategy_size("TP") > 1:
+ tp_mesh = VESCALE_DEVICE_MESH.get_tensor_parallel_mesh()
x = vescale.distribute_tensor(x, tp_mesh, [Replicate()])
y = vescale.distribute_tensor(y, tp_mesh, [Replicate()])
return x, y
@@ -169,6 +175,7 @@ def init_method(self):
# the GPT loads pretrained weights from OpenAI GPT2 repository on Huggingface
return "scratch"
+ @unittest.skip("fix me: CPU dist optimizer")
@with_comms_device("cpu")
def test_1d_dp_gpt_cpu(self):
"""
@@ -196,9 +203,13 @@ def _test_1d_dp_gpt(self):
model.to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Initialize global DeviceMesh
- device_mesh = veDeviceMesh.init_device_mesh("cuda", mesh_shape=(dp_size,), mesh_dim_names=("DP",))
+ VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(dp_size,), mesh_dim_names=("DP",))
# Wrap model with DDP module. Since 1D global DeviceMesh cannot slice sub-DeviceMesh. we have to rely on get_data_parallel_dim_groups()
- dp_comm = veDeviceMesh["DP"] if veDeviceMesh.ndim > 1 else veDeviceMesh.get_data_parallel_dim_groups()
+ dp_comm = (
+ VESCALE_DEVICE_MESH["DP"]
+ if VESCALE_DEVICE_MESH.ndim > 1
+ else VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
+ )
model = DDP(
model,
data_pg_or_device_mesh=dp_comm,
@@ -209,6 +220,7 @@ def _test_1d_dp_gpt(self):
# Train model
self.train(model, optimizer, data_set, use_dist_tensor=False)
+ @unittest.skip("fix me: CPU dist optimizer")
@with_comms_device("cpu")
def test_1d_tpsp_gpt_cpu(self):
"""
@@ -235,8 +247,8 @@ def _test_1d_tpsp_gpt(self):
tp_size = 2
model, data_set = self._prepare()
# Initialize global DeviceMesh
- device_mesh = veDeviceMesh.init_device_mesh("cuda", mesh_shape=(tp_size,), mesh_dim_names=("TP",))
- model = parallelize_module(model, device_mesh, nanoGPT_plan)
+ VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(tp_size,), mesh_dim_names=("TP",))
+ model = parallelize_module(model, VESCALE_DEVICE_MESH.get(), nanoGPT_plan)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Train model
self.train(model, optimizer, data_set, use_dist_tensor=True)
@@ -262,7 +274,7 @@ def _prepare(self):
def _process_data(self, x, y, use_dist_tensor=False):
if use_dist_tensor:
- tp_mesh = veDeviceMesh.get()
+ tp_mesh = VESCALE_DEVICE_MESH.get()
x = vescale.distribute_tensor(x, tp_mesh, [Replicate()])
y = vescale.distribute_tensor(y, tp_mesh, [Replicate()])
return x, y
diff --git a/vescale/checkpoint/README.md b/vescale/checkpoint/README.md
index bdd9aca..26b8d1f 100644
--- a/vescale/checkpoint/README.md
+++ b/vescale/checkpoint/README.md
@@ -1,4 +1,4 @@
-# vescale.checkpoint
+# veScale Checkpoint
`vescale.checkpoint` is an automatic distributed checkpointing system for LLM training and inference.
@@ -29,20 +29,24 @@ abstracting away the complexities of underlying details such as process rank and
- Saving checkpoint:
-```
+```python
# prepare checkpoint state for the model and optimizer
checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
# save the checkpoint
vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state)
```
+
- Loading checkpoint (under different world size or 3D parallelism degrees):
-```
+
+```python
# prepare checkpoint state for the model and optimizer
checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
# load the checkpoint
vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state)
```
-- More examples can be found under `/test/checkpoint` and `/examples`.
+- APIs can be found in: `/vescale/checkpoint/__init__.py`
+
+- More examples can be found under `/test/checkpoint/*.py` and `/examples/`
- Original examples can be found in PyTorch [Distributed Checkpoint](https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint)
\ No newline at end of file
diff --git a/vescale/checkpoint/planner/vescale/vescale_planner.py b/vescale/checkpoint/planner/vescale/vescale_planner.py
index 427c4a7..c178991 100644
--- a/vescale/checkpoint/planner/vescale/vescale_planner.py
+++ b/vescale/checkpoint/planner/vescale/vescale_planner.py
@@ -35,7 +35,7 @@
find_state_dict_object,
)
-from vescale.devicemesh_api import veDeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
logger: logging.Logger = logging.getLogger(__file__)
@@ -230,7 +230,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
op=dist.irecv,
tensor=recv_tensor,
peer=k,
- group=veDeviceMesh.get_data_parallel_dim_groups(),
+ group=VESCALE_DEVICE_MESH.get_data_parallel_dim_groups(),
)
recv_tensors[k] = recv_tensor
p2p_ops.append(recv_op)
@@ -241,7 +241,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b
op=dist.isend,
tensor=obj.local_tensor,
peer=writer_rank,
- group=veDeviceMesh.get_data_parallel_dim_groups(),
+ group=VESCALE_DEVICE_MESH.get_data_parallel_dim_groups(),
)
p2p_ops.append(send_op)
diff --git a/vescale/ddp/README.md b/vescale/ddp/README.md
index 26755b5..4a6206d 100644
--- a/vescale/ddp/README.md
+++ b/vescale/ddp/README.md
@@ -1,24 +1,37 @@
# veScale Distributed Data Parallel (DDP)
-## Overview
+## TLDR
-`Distributed Data Parallel` (`DDP`) is a distributed training strategy that partitions the input data across multiple devices, such as multiple GPUs, and replicates the model on each device. On top of this, various ZeRO features can be implemented.
+
-veScale `DDP` is primarily inherited from [Megatron-LM's DDP](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel.py). We extend the compatibility of the `DDP` implementation with our DTensor.
+## What is DDP?
-## Implementation
+`Distributed Data Parallel` (`DDP`) is the most used parallelism strategy for distributed training. It partitions the input data batch across multiple devices, replicates the model on each device, and synchronizes gradient (e.g. with `AllReduce`) in the background.
-`DDP` is a module wrapper that creates a flattened grad buffer to store the gradients produced by the model backwarding. This is achieved by adding a hook to the grad_fn of the parameters, which fill DTensor gradient outputed by PyTorch Autograd engine to the pre-allocated grad buffer. The purpose of grad buffer is to accelerate the all-reduce process for gradient updates during distributed training, as it only needs to be performed once for the entire buffer, rather than once per parameter.
+veScale `DDP` is primarily inherited from [Megatron-LM's DDP](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel.py) for its performance and compatibility with ZeRO optimizer. We extend and enhance the original DDP with extra features surrounding veScale `DTensor` and `DModule`:
-On the basis of this, there are some optimizations can be achieved:
+- conversion between `DTensor` and `Tensor` gradients
-1. Overlap gradient all-reduce with backwarding procedure. We can further split the grad buffer into several buckets. Once all gradient in a bucket is ready, we can immediately trigger the gradient all-reduce rather than waiting until the whole grad buffer is ready.
+- support nested gradient synchronization with `DModule` (for Sequence Parallel)
-2. Reduce-scatter the gradient rather than all-reduce gradient if we have a veScale `DistributedOptimizer` (a ZeRO 2+ optimizer) installed.
+- support gradient synchronization for dynamic control flow
-## Example
-Following shows a simple code. For more examples, see `/test/parallel/ddp_optim/*.py`
+## How does DDP work?
+
+`DDP` is a module (`DModule`) wrapper that creates a flattened _Gradient Buffer_ that stores the gradients produced by the model backward.
+(This is achieved by adding a hook to the `grad_fn` of the model parameters, which fills `DTensor` gradient outputed by PyTorch Autograd engine to the pre-allocated grad buffer.)
+The purpose of _Gradient Buffer_ is to both accelerate gradient synchronization and reduce memory fragmentation, as it only needs to be performed once for the entire buffer, rather than once per parameter.
+
+For extreme performance, the _Gradient Buffer_ is further divided into multiple _Bucket_s such that the backward compute and gradient synchronization of each _Bucket_ can be overlapped. As soon as all gradients in a _Bucket_ are generated, we can immediately trigger the gradient synchronization rather than waiting until the whole _Gradient Buffer_ is ready.
+
+The gradient synchronization can be either `AllReduce` or `ReduceScatter` under the DDP hood:
+
+- `AllReduce` is used when no _ZeRO_ optimizer
+
+- `ReduceScatter` is used when _ZeRO_ optimizer (e.g., `DistributedOptimizer`) exists
+
+## How to use DDP?
```python
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
@@ -31,16 +44,16 @@ mlp = MLP()
# create 2-dim DeviceMesh, the first for data-parallel, while the second for tensor-parallel.
device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP"))
-# parallelize torch-native model into TP model (see: `/vescale/dmodule/README.md`)
-tp_mlp = parallelize_module(mlp, device_mesh["TP"], param_and_fwd_sharding_plan)
+# parallelize torch-native model into TP model
+tp_mlp = parallelize_module(mlp, device_mesh["TP"], sharding_plan)
# wrap TP model with `DDP`
dp_tp_mlp = DDP(
- # feed the paralellized module
- module=tp_mlp,
+ # feed the TP model
+ tp_mlp,
# feed DP's sub-mesh or just `device_mesh` (i.e., by default we treat the first dim of devicemesh as data-parallelism).
- data_pg_or_device_mesh=device_mesh["DP"],
- # choose whether overlap gradient all-reduce with backwarding procedure for speeding up
+ device_mesh["DP"],
+ # choose whether overlap gradient all-reduce with backward
overlap_grad_reduce=True or False,
# choose whether used `DistributedOptimizer`
# if True, `DDP` will be used with `DistributedOptimizer`, so `DDP` reduce-scatter the gradient along data-parallel ranks.
@@ -53,3 +66,8 @@ dp_tp_mlp(torch.rand(...)).sum().bakward()
# all-reduce / reduce-scatter the gradient across the DP world.
dp_tp_mlp.finish_grad_sync()
```
+
+- APIs can be found in `/vescale/ddp/distributed_data_parallel.py`
+
+- More examples can be found in `/test/parallel/ddp_optim/test_ddp.py`
+
diff --git a/vescale/devicemesh_api/README.md b/vescale/devicemesh_api/README.md
index 9b791d0..d2cc6ad 100644
--- a/vescale/devicemesh_api/README.md
+++ b/vescale/devicemesh_api/README.md
@@ -1,74 +1,70 @@
-# veScale nD Device Mesh
+# VeDeviceMesh for nD Parallelism
-## Overview
-`veDeviceMesh` is an advanced API that is built on top of PyTorch upstream’s higher level abstraction [`DeviceMesh`](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html). This API enhances the existing capabilities of DeviceMesh, enabling effective 5D parallelization strategies and easy-to-use APIs.
+## TLDR
-## Implementation
-Designed to seamlessly integrate with veScale’s Distributed `Data Parallel`, `Tensor/Sequence` (TP/SP), `DistributedOptimizer` and `Pipeline Parallel` APIs, veDeviceMesh ensures superior compatibility and performance by meticulously managing sub-DeviceMeshes and process groups. Additionally, veDeviceMesh provides user-friendly tools for querying strategy coordinates, attributes of parallel dimensions, and overall `DeviceMesh` configurations, making it a highly accessible and efficient solution for developers.
+
-veDeviceMesh embraces following user practices:
-1. “A DeviceMesh, but better”
-2. One “Mesh” fits all: users don’t need to worry about meddling with DeviceMesh and ProcessGroups’ throughout the course of training. Additionally, users make the most out of the same DeviceMesh to enable hybrid parallelization training.
-3. Easy to extend: for more refined capabilities for imminent parallelization methods in the future, veDeviceMesh provides mature APIs to extend new functionalities without breaking the semantics of communication
+(`*` is under development.)
-## Example
-Below is a simple demo of veDeviceMesh API.
+## What is VeDeviceMesh?
+
+`VeDeviceMesh (veScale Device Mesh)` is an advanced API that is built on top of PyTorch native's [`DeviceMesh`](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html). This API enhances the existing capabilities of `DeviceMesh`, enabling effective _nD parallelism strategies_, _checkpointing_, and easy-to-use APIs, with ideals below:
+
+- “A DeviceMesh, but better”
+
+- One “Mesh” fits all: users don't need to worry about meddling with DeviceMesh and ProcessGroups' throughout the course of training. Additionally, users make the most out of the same DeviceMesh to enable hybrid parallelization training.
+
+- Easy to extend: for more refined capabilities for imminent parallelization methods in the future, `VeDeviceMesh` provides mature APIs to extend new functionalities without breaking the semantics of communication
+
+## How does VeDeviceMesh work?
+
+`VeDeviceMesh` wraps around PyTorch `DeviceMesh` with APIs that seamlessly integrate with APIs of veScale's `DModule`, `DDP`, `DistributedOptimizer`, `Pipeline Parallel`, and `Checkpoint`.
+
+`VeDeviceMesh` further implements advanced features surrounding `DeviceMesh`:
+
+- rank mapping between local rank and global rank or between strategy coordinates and global rank
+
+- submesh mapping between global mesh and submeshes or between local submesh and neighbor submeshes
+
+- [in future] fault tolerance with reconfigurable meshes
+
+## How to use VeDeviceMesh?
```python
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.dmodule.api import parallelize_module
-from vescale.devicemesh_api import veDeviceMesh
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from ... import GPT
+# create torch-native model as usual
+model = GPT()
-dp_size = tp_size = 2
-data_set = ...
-sharding_plan = ...
+# initialize a VeDeviceMesh containing a global DeviceMesh with size of (2, 2)
+VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(2, 2), mesh_dim_names=("DP", "TP"))
-# load GPT-2 model from pretrained weights
-model = GPT()
+# use VeDeviceMesh to obtain global DeviceMesh's tensor parallelism view
+if VESCALE_DEVICE_MESH.get_strategy_size("TP") > 1:
+ # wrap DModule (TP/SP)
+ model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], sharding_plan, ...)
-# initialize veDeviceMesh API with a global DeviceMesh of size (2, 2)
-veDeviceMesh.init_device_mesh(
- "cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("DP", "TP"),
-)
-...
-# wrap DModule (TP/SP)
-if veDeviceMesh.get_strategy_size("TP") > 1:
- # use veDeviceMesh to obtain global DeviceMesh's tensor parallelism view
- model = parallelize_module(model, device_mesh["TP"], shardin_plan, ...)
-
-# wrap DDP module
-if veDeviceMesh.get_strategy_size("DP") > 1:
- # use veDeviceMesh to obtain ProcessGroup for data parallelism
- model = DDP(
- model,
- veDeviceMesh["DP"],
- ...
- )
-
-# build base optimizer
-optimizer = ...
-
-# build distributed optimizer
-if veDeviceMesh.get_strategy_size("DP") > 1:
- optimizer = DistributedOptimizer(
- optimizer,
- models=[model],
- )
-
-# Train model with fwd+bwd+step
+# use VeDeviceMesh to obtain global DeviceMesh's data parallelism view
+if VESCALE_DEVICE_MESH.get_strategy_size("DP") > 1:
+ # wrap DDP module
+ model = DDP(model, VESCALE_DEVICE_MESH["DP"], ...)
+
+# use VeDeviceMesh to build distributed optimizer
+if VESCALE_DEVICE_MESH.get_strategy_size("DP") > 1:
+ optimizer = DistributedOptimizer(torch.optim.Adam, models=[model], ...)
+
+# Train model
for X, Y in data_set:
- # use veDeviceMesh to tensor parallel dimension size
- tp_mesh = veDeviceMesh.get_tensor_parallel_mesh()
- ...
optimizer.zero_grad()
- _, output = model(X, Y)
- loss = ...
+ loss = model(X, Y)
loss.backward()
- ...
optimizer.step()
```
-- More examples can be found under `/test/parallel/devicemesh_api/*.py`
+- APIs can be found in `/vescale/devicemesh_api/api.py`
+
+- More examples can be found under `/test/parallel/devicemesh_api/*.py`
\ No newline at end of file
diff --git a/vescale/devicemesh_api/__init__.py b/vescale/devicemesh_api/__init__.py
index ee60e47..a9c8ea7 100644
--- a/vescale/devicemesh_api/__init__.py
+++ b/vescale/devicemesh_api/__init__.py
@@ -15,4 +15,4 @@
#
################################################################################
-from .device_mesh_api import veDeviceMesh
+from .api import VESCALE_DEVICE_MESH
diff --git a/vescale/devicemesh_api/device_mesh_api.py b/vescale/devicemesh_api/api.py
similarity index 93%
rename from vescale/devicemesh_api/device_mesh_api.py
rename to vescale/devicemesh_api/api.py
index 850d42b..a2c2eb8 100644
--- a/vescale/devicemesh_api/device_mesh_api.py
+++ b/vescale/devicemesh_api/api.py
@@ -22,7 +22,7 @@
from typing import Optional, List, Tuple, Union, Dict
from torch.distributed.distributed_c10d import ProcessGroup
-__all__ = ["veDeviceMesh"]
+__all__ = ["VESCALE_DEVICE_MESH"]
class VeDeviceMesh:
@@ -79,7 +79,7 @@ def init_device_mesh(
check_uniqueness (bool): This advanced argument is used to prevent users from spoiling global
DeviceMesh API by creating multiple copies in a large code repository.
- Set to True to allow veDeviceMesh API to check the "global device mesh" is only initialized once.
+ Set to True to allow VESCALE_DEVICE_MESH API to check the "global device mesh" is only initialized once.
Otherwise, users can create as many DeviceMeshes as they want just like with upstream Devicemesh.
Returns:
@@ -90,13 +90,13 @@ def init_device_mesh(
Example:
>>> # xdoctest: +SKIP
- >>> from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
+ >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH
>>>
- >>> # Example 1: create a one-dimensional DeviceMesh
- >>> mesh_1d = veDeviceMesh.init_device_mesh("cuda", mesh_shape=(8,))
+ >>> # Example 1: initialize the global DeviceMesh as a one-dimensional DeviceMesh
+ >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(8,))
>>>
- >>> # Example 2: create a two-dimensional DeviceMesh
- >>> mesh_2d = veDeviceMesh.init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
+ >>> # Example 2: re-initialize the global DeviceMesh as a two-dimensional DeviceMesh
+ >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
Limitation: we currently only support fixed sized DeviceMesh with 1 to 3 dimensions. We will loosen this constraint in future.
"""
@@ -199,15 +199,15 @@ def get_strategy_coordinate(self, local_rank=None) -> List[int]:
Coordinate of local rank mapped to the global DeviceMesh's parallel dimensions.
Example:
- >>> from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
+ >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH
>>> dp_size, tp_size = 2, 2
>>> # Initialize global device mesh of (dp_size=2, tp_size=2)
- >>> _ = veDeviceMesh.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"))
+ >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"))
>>> local_rank = torch.distributed.get_rank() # local_rank is 0
0
- >>> veDeviceMesh.get_strategy_coordinate(local_rank)
+ >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank)
[0, 0]
- >>> veDeviceMesh.get_strategy_coordinate(3)
+ >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(3)
[1, 1]
"""
assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!"
@@ -230,19 +230,19 @@ def lookup_rank(self, dim: Union[int, str]) -> int:
Specified parallel strategy 'rank' of a global rank.
Example:
- >>> from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
+ >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH
>>> dp_size, tp_size = 2, 2
>>> # Initialize global device mesh of (dp_size=2, tp_size=2)
- >>> _ = veDeviceMesh.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"))
+ >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"))
>>> local_rank = torch.distributed.get_rank() # local_rank = 0
0
- >>> veDeviceMesh.get_strategy_coordinate(local_rank)
+ >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank)
[0, 0]
>>> index = 1
- >>> veDeviceMesh.lookup_rank(index) # local_rank is 0
+ >>> VESCALE_DEVICE_MESH.lookup_rank(index) # local_rank is 0
0
>>> dim_name = "DP"
- >>> veDeviceMesh.lookup_rank(dim_name) # local_rank is 0
+ >>> VESCALE_DEVICE_MESH.lookup_rank(dim_name) # local_rank is 0
0
"""
assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!"
@@ -472,4 +472,4 @@ def shape(self) -> Tuple[int, ...]:
return tuple(device_mesh.mesh.shape)
-veDeviceMesh = VeDeviceMesh()
+VESCALE_DEVICE_MESH = VeDeviceMesh()
diff --git a/vescale/dmodule/README.md b/vescale/dmodule/README.md
index 4217973..4073a45 100644
--- a/vescale/dmodule/README.md
+++ b/vescale/dmodule/README.md
@@ -1,4 +1,8 @@
-# veScale DModule (Distributed Module)
+# veScale DModule for Tensor Parallel & Sequence Parallel
+
+## TLDR
+
+
## Why veScale DModule?
@@ -20,7 +24,7 @@
- [experimental] given by "automatical plan generation" of veScale
- handles gradient synchronization automatically in backward
- support deferred initialization with `deferred_init()` (i.e., initialize with `Fake` Tensor without allocating memory, then shard Fake Tensor with TP, and then materialize only a shard of `Tensor` on device)
- - support third-party plug-in Module (e.g. APEX)
+ - support third-party plug-in Module (e.g. `APEX`)
- provide patch interface for customized Module-level hacking
- extend to optional DP, optional FSDP, and optional EP (in the future)
- provide debuggability for easy dumping, printing, listing, etc.
@@ -29,7 +33,7 @@
- veScale `DModule` is inspired by PyTorch's [`parallelize_module`](https://pytorch.org/docs/stable/_modules/torch/distributed/tensor/parallel/api.html#parallelize_module), but is developed with explicit Module-level abstraction with complete features for our production usage.
-- veScale `DModule` extends PyTorch `parallelize_module` with extra features as below (i.e., major differences from PyTorch-v2.2.0):
+- veScale `DModule` extends PyTorch `parallelize_module` with extra features as below:
- nD Tensor Parallelism
- Sequence Parallelism
- auto gradient synchronization
@@ -38,7 +42,7 @@
- module-level patch interface
- [experimental] automatical plan generation
-## How to use veScale DModule ``manually''?
+## How to use veScale DModule manually?
- Example of `MLP`:
@@ -64,14 +68,14 @@
# parallelize model into DModule with "maunal plans"
dmlp = parallelize_module(mlp,
DeviceMesh("cuda", [0, 1, 2, 3]),
- {
- "parameter" : {
+ { # sharding plan
+ "parameter" : { # appoint "which param" with what [placements]
"fc1.weight": [Shard(0)],
"fc1.bias": [Shard(0)],
"fc2.weight": [Shard(1)],
"fc2.bias": [Replicate()],
},
- "forward" : {
+ "forward" : { # appoint "which activation" with what [placements]
"fc1.input": [[Replicate()]], # change to Shard() for SP/DP
"fc2.output": [[Replicate()]],
}
@@ -85,6 +89,6 @@
```
-- More details can be found in `/vescale/dmodule/api.py`
+- APIs can be found in `/vescale/dmodule/api.py`
-- More examples can be found under `/test/dmodule/*.py`
+- More examples can be found under `/test/dmodule/*.py`
\ No newline at end of file
diff --git a/vescale/dmp/README.md b/vescale/dmp/README.md
new file mode 100644
index 0000000..7b3ba97
--- /dev/null
+++ b/vescale/dmp/README.md
@@ -0,0 +1,3 @@
+# Auto TP & SP Plan
+
+# Coming Soon
\ No newline at end of file
diff --git a/vescale/dtensor/README.md b/vescale/dtensor/README.md
index ae5679f..c54ecdb 100644
--- a/vescale/dtensor/README.md
+++ b/vescale/dtensor/README.md
@@ -1,5 +1,9 @@
# DTensor (Distributed Tensor)
+## TLDR
+
+
+
## Why DTensor?
- `torch.Tensor` lacks the semantic of being distributed across multiple devices and running distributed operators
@@ -12,71 +16,63 @@
- `DTensor` transparently handles all distributed logic under the hood (sharded storage on each device, the collective communication among devices, and the operator kernel split across devices)
-- `DTensor` is implemented by a wrapper class on `torch.tensor` with a meta data `DTensorSpec` describing:
+- `DTensor` is implemented by a wrapper class on `torch.Tensor` with a meta data `DTensorSpec` describing:
- which multiple devices (`DeviceMesh`) is distributed upon
- - how is `DTensor` placed (`placements`) on the `DeviceMesh`; there are three main `placements`:
-
- - `Replicate`: `DTensor` is replicated on the `DeviceMesh`
- - `Shard`: `DTensor` is sharded on the `DeviceMesh`
- - `Partial`: `DTensor` is a partial product on the `DeviceMesh` with pending sum (`AllReduce`) to be a total product
-
- - what is the global tensor shape & stride (`tensor_meta`) of this `DTensor`
-
-- `DTensor` computation is implemented by `ShardingPropagator` which propagates placements from input to output for each operator with pre-registered sharding rules and strategies
+ - it can be 1D mesh of two GPUs: `DeviceMesh("cuda", [0, 1])`
+ - it can be 2D mesh of four GPUs: `DeviceMesh("cuda", [[0, 1], [2, 3]])`
-## How to use DTensor ``manually''?
-
-- Example of `matmul`:
+ - how is `DTensor` placed (`Placement`) on the `DeviceMesh`:
+
+ - there are three main `Placement`:
- ``` python
- # create a four-device mesh
- device_mesh = DeviceMesh("cuda", [0, 1, 2, 3])
+ - `Shard()`: `DTensor`'s `` is sharded on the `DeviceMesh`
+ - `Replicate`: `DTensor` is replicated on the `DeviceMesh`
+ - `Partial`: `DTensor` is a partial product on the `DeviceMesh` with pending sum (`AllReduce`) to be a total product
+
+ - where a list of `Placement` is needed to define the `placements` of a `DTensor`:
- # single device matmul
- t1 = torch.ones(12, 8, device="cuda")
- t2 = torch.ones(8, 16, device="cuda")
- t3 = torch.mm(t1, t2)
+ - `placements = [Shard(1)]` means `DTensor`'s tensor dim #1 is sharded along `DeviceMesh`'s dim #0 (i.e., the #0 element in the list)
- # multiple device matmul
- dt1 = distribute_tensor(t1, device_mesh, [Shard(dim=1)]) # colwise shard t1 on device mesh
- dt2 = distribute_tensor(t2, device_mesh, [Shard(dim=0)]) # rowwise shard t2 on device mesh
- dt3 = torch.mm(dt1, dt2)
- assert isinstance(dt3, DTensor)
- assert dt3.placements[0].is_partial() # product t3 is partial sharded on device mesh
- dt4 = dt3.redistribute(device_mesh, [Replicate()]) # reshard t3 with allreduce to replicate
+ - `placements = [Shard(1), Shard(0)]` means `DTensor`'s tensor dim #1 is sharded along `DeviceMesh`'s dim #0 and `DTensor`'s tensor dim #0 is sharded along `DeviceMesh`'s dim #1
- # match DTensor and Tensor result
- assert torch.equal(dt4.to_local(), t3)
- ```
+ - `placements = [Shard(1), Replicate()]` means `DTensor`'s tensor dim #1 is sharded along `DeviceMesh`'s dim #0 and `DTensor`'s rest tensor dim #0 is replicated along `DeviceMesh`'s dim #1
-- More examples can be found under `/test/dtensor/*/*.py`
+ - what is the global tensor shape & stride (`TensorMeta`) of this `DTensor`
-- Original examples can be found in PyTorch [DTensor](https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor).
+- `DTensor` operators (e.g., `torch.add`) are implemented by `ShardingPropagator` which propagates `placements` from input to output for each operator with pre-registered sharding rules and strategies
-## What is veScale DTensor?
+## What is veScale DTensor? How's different from PyTorch DTensor?
-- veScale is a PyTorch native framework rooted in PyTorch DTensor
+- veScale is a PyTorch-native framework rooted in _**PyTorch DTensor**_
-- veScale DTensor has been and will be synchronizing with PyTorch DTensor
+- _**veScale DTensor**_ extends and enhances the _**PyTorch DTensor**_ for our production standard with extra features as below:
-- veScale DTensor shares the majority of code of PyTorch DTensor, but extends it with extra features as below (i.e., major differences from PyTorch DTensor v2.2.0) for our production usage:
+ - enabled "correct random ops" under abitrary sharding and uneven sharding, i.e., always guarantee random op sharded on multi device is equal to random op on a single device.
- - enabled DTensor support for third-party plug-in ops (e.g., APEX) by unleashing `DTensor.data_ptr` and handling asynchronous collective tensors (e.g., in `from_local`, `to_local`, `redistribute`)
+ - enabled DTensor support for third-party plug-in ops (e.g., `APEX`) by unleashing `DTensor.data_ptr` and handling asynchronous collective tensors (e.g., in `from_local`, `to_local`, `redistribute`)
- make implicit `_Partial` to explicit `Partial` placement for optimized initialization, output, and checkpoint (with an extra dispatch mode)
- - enabled DTensor ops that were not implemented in PyTorch:
- - `argmax` and `argmin`
+ - enabled DTensor ops that were not implemented in PyTorch for forward or/and backward:
+ - `argmax`
+ - `argmin`
- `topk`
- `_unique2`
- - `scatter_` and `scatter`
+ - `scatter_`
+ - `scatter`
- `select`
- `alias`
- - `index_put_` and `index_put`
+ - `index_put_`
+ - `index_put`
+ - `index_add_`
- `_scaled_dot_product_flash_attention`
- `_scaled_dot_product_efficient_attention`
+ - `expand_as`
+ - `one_hot`
+ - `where`
+ - `Embedding` in vocabular parallel
- support uneven sharding in conversion between `DTensor` and `torch.Tensor`
@@ -105,10 +101,45 @@
- [experimental] developed `InterleavedShard` placement to support merged QKV in MHA
+ - [experimental] extreme performance with C++ DTensor
+
+ - [experimental] extreme performance with dispatching-free DTensor
+
+## How to use veScale DTensor manually?
+
+- Example of `matmul`:
+
+ ``` python
+ # create a four-device mesh
+ device_mesh = DeviceMesh("cuda", [0, 1, 2, 3])
+
+ # single device matmul
+ t1 = torch.ones(12, 8, device="cuda")
+ t2 = torch.ones(8, 16, device="cuda")
+ t3 = torch.mm(t1, t2)
+
+ # multiple device matmul
+ dt1 = distribute_tensor(t1, device_mesh, [Shard(dim=1)]) # colwise shard (tensor dim 1) t1 along device mesh's dim 0
+ dt2 = distribute_tensor(t2, device_mesh, [Shard(dim=0)]) # rowwise shard (tensor dim 0) t2 along device mesh's dim 0
+ dt3 = torch.mm(dt1, dt2)
+ assert isinstance(dt3, DTensor)
+ assert dt3.placements[0].is_partial() # product t3 is partial sharded on device mesh
+ dt4 = dt3.redistribute(device_mesh, [Replicate()]) # reshard t3 with allreduce to replicate
+
+ # match DTensor and Tensor result
+ assert torch.equal(dt4.to_local(), t3)
+ ```
+
+- APIs can be found under `/vescale/dtensor/api.py`
+
+- More examples can be found under `/test/dtensor/*/*.py`
+
+- Original examples can be found in PyTorch [DTensor](https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor).
+
## What if encountering an operator that is not supported by DTensor yet?
--- Register DTensor "Ops" for Sharding Propagation!
+-- _Register DTensor "Ops" for Sharding Propagation!_
### Why register DTensor Ops for sharding propagation?
@@ -222,18 +253,33 @@ def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
```
-## How to generate random numbers in DTensor as if it's from a single GPU
+## How to generate random numbers in DTensor as if it's from a single GPU?
+
+### Motivation
+
+Ideally, DTensor should provide single-device abstraction even for random ops (e.g. `dtensor.randn`, `nn.Dropout`, and ``), i.e., random value generated on single device should be identical to collective of random shard on multiple devices.
+
+
+### Problem
-In veScale, we introduce a `ThreadBasedRNGTracker` for managing the RNG states across different GPUs.
-As a result, we can generate random DTensors that are identical to the ones from single GPUs.
-To use the feature, build and install a patched pytorch and set the environment variable `VESCALE_SINGLE_DEVICE_RAND=1`.
+PyTorch DTensor (i.e., `OffsetBasedRNGTracker`) does not produce the random values on multiple devices identical to single GPU execution for random operators (e.g. `dtensor.randn`, `nn.Dropout`, and ``).
+The key problem lies in that the CUDA random numbers are not generated "sequentially" and cannot be simply offsetted by rank ids, but instead are generated "simultaneously" by multiple CUDA threads and only be sharded by CUDA thread ids!
+
+### Solution
+
+In veScale, we introduce a `ThreadBasedRNGTracker` for correcting the RNG states across different GPUs, enabling generation of correct DTensor that are identical to the ones from single GPUs for any random ops.
+
+To use the feature, build and install a patched PyTorch of veScale and set the environment variable `VESCALE_SINGLE_DEVICE_RAND=1`.
+
+### Details
+
+Whenever invoking a randomized operation on a DTensor, `ThreadBasedRNGTracker` passes its sharding info to the C++/Cuda side of PyTorch through the RNG state.
+This resolves the issue that PyTorch DTensor's `OffsetBasedRNGTracker` does not produce the output identical to single GPU executions.
-Whenever invoking a randomized operation on a DTensor, `ThreadBasedRNGTracker` passes its sharding info to the C++/Cuda side of pytorch through the RNG state.
-This resolves the issue that OffsetBasedRNGTracker does not produce the output identical to single GPU executions.
For example, consider generating `x = torch.rand(4)` given the current random seed and
a global offset. In Cuda's RNG implementation, random numbers are accessed via a triple
-(seed, thread id, offset).
+`(seed, thread id, offset)`.
On a single GPU, 4 GPU threads is created and the i-th thread fills the entry `x[i]`
with `rand(seed, i, offset)`. That is, we have
@@ -241,16 +287,14 @@ with `rand(seed, i, offset)`. That is, we have
| Thread 0 | Thread 1 | Thread 2 | Thread 3 |
x = | rand(0, offset) | rand(1, offset) | rand(2, offset) | rand(3, offset) |
```
-After the execution of torch.rand(4), the global offset increments by 4, which is the
+After the execution of `torch.rand(4)`, the global offset increments by 4, which is the
granularity of cuda's RNG offsets.
The global offset increments by the size of the randomness used in each thread, rounded
up to the nearest multiple of 4. For instance, if 1000 GPU threads is used to generate
-7000 random numbers, each thread takes 7 random numbers from Cuda RNG and the global offset
-increases by 8 afterward.
+7000 random numbers, each thread takes 7 random numbers from Cuda RNG and the global offset increases by 8 afterward.
-However, using OffsetBasedRNGTracker along with an un-patched pytorch, it outputs a
-different tensor given 2 GPUs.
+However, using `OffsetBasedRNGTracker`, it outputs a different tensor given 2 GPUs.
```
| GPU 0 | GPU 1 |
| Thread 0 of GPU 0 | Thread 1 of GPU 0 | Thread 0 of GPU 1 | Thread 1 of GPU 1 |
@@ -267,4 +311,10 @@ x = | rand(seed, 0, offset) | rand(seed, 1, offset) | rand(seed, 2, offset) | ra
```
And after the execution, the global offset should increment by 4.
This can be done if we pass the sharding info into Cuda functions that generate these
-outputs.
\ No newline at end of file
+outputs.
+
+
+## Acknowledgement
+
+We would like to acknowledge the assistance of and collaboration with
+the [PyTorch DTensor team](https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor).
\ No newline at end of file
diff --git a/vescale/dtensor/__init__.py b/vescale/dtensor/__init__.py
index f8e012b..2c2763d 100644
--- a/vescale/dtensor/__init__.py
+++ b/vescale/dtensor/__init__.py
@@ -334,7 +334,7 @@ def rand(
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from an uniform distribution
- with mean 0 and variance 1. The global shape of the tensor is defined by the variable
+ on the interval [0, 1]. The global shape of the tensor is defined by the variable
argument ``size``.
It will be on device type of device mesh; presetting default cuda rank is a must.
diff --git a/vescale/dtensor/device_mesh.py b/vescale/dtensor/device_mesh.py
index 9e03727..c9647b9 100644
--- a/vescale/dtensor/device_mesh.py
+++ b/vescale/dtensor/device_mesh.py
@@ -266,6 +266,34 @@ def __init__(
# step 1: try to create default world pg.
if pg is None:
pg = self._get_or_create_default_group()
+ else:
+ # TODO: this logic only applies when device_type is cuda
+ pg_world_size = get_world_size(group=pg)
+ device_handle = _get_device_handle(self.device_type)
+ num_devices_per_host = device_handle.device_count()
+ if pg_world_size > num_devices_per_host and pg_world_size % num_devices_per_host != 0:
+ raise RuntimeError(
+ f"DeviceMesh only support homogeneous hardware, but found "
+ f"{pg_world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+ )
+ if self.device_type == "cuda":
+
+ def _get_current_device():
+ try:
+ if torch.cuda.is_available():
+ return torch.cuda.current_device()
+ else:
+ return None
+ except AssertionError as e:
+ return None
+
+ device_handle = _get_device_handle(self.device_type)
+ num_devices_per_host = device_handle.device_count()
+ local_rank = get_rank() % num_devices_per_host
+ if local_rank != _get_current_device():
+ warnings.warn("Remember to set cuda device id to local rank!!!")
+ device_handle = _get_device_handle(self.device_type)
+ device_handle.set_device(local_rank)
# step 2: validate the mesh before following usage.
if _validate_mesh:
diff --git a/vescale/dtensor/ops/tensor_ops.py b/vescale/dtensor/ops/tensor_ops.py
index 9d4646a..c28cc5f 100644
--- a/vescale/dtensor/ops/tensor_ops.py
+++ b/vescale/dtensor/ops/tensor_ops.py
@@ -940,9 +940,8 @@ def index_add_rule(op_schema: OpSchema) -> OutputSharding:
if not index_spec.is_replicated():
raise RuntimeError("index must be replicate for index_add op")
- if src_spec.sums or input_spec.sums:
- # TODO: maybe we should allow partial here.
- raise NotImplementedError("src and input can not be partial for index_add op")
+ if src_spec.sums != input_spec.sums:
+ raise NotImplementedError("src and input should be both partial or non-partial for index_add op")
if src_spec.ndim != input_spec.ndim:
raise RuntimeError("invalid index_add op detected")
diff --git a/vescale/dtensor/ops/view_ops.py b/vescale/dtensor/ops/view_ops.py
index 36a3475..ace7730 100644
--- a/vescale/dtensor/ops/view_ops.py
+++ b/vescale/dtensor/ops/view_ops.py
@@ -681,7 +681,8 @@ def expand_as_prop(op_schema: OpSchema) -> OutputSharding:
op=aten.expand.default, args_schema=(source, tuple(global_out_shape)), kwargs_schema=op_schema.kwargs_schema
)
expand_sharding_out = _reshape_prop(new_op_schema, ops[Tensor.expand])
- expand_sharding_out.output_spec.placements = dst.placements
+ if any(p.is_shard() for p in dst.placements):
+ expand_sharding_out.output_spec.placements = dst.placements
expand_sharding_out.needs_redistribute = False
expand_sharding_out.suggested_schema = None
return expand_sharding_out
diff --git a/vescale/dtensor/random.py b/vescale/dtensor/random.py
index 206d7cf..d2e5767 100644
--- a/vescale/dtensor/random.py
+++ b/vescale/dtensor/random.py
@@ -11,6 +11,7 @@
import contextlib
import warnings
from typing import Dict, List, Optional, Tuple
+from math import prod
import os
import torch
@@ -19,6 +20,7 @@
from vescale.dtensor.device_mesh import _get_device_handle, DeviceMesh
from vescale.dtensor.placement_types import DTensorSpec, Shard
+from vescale.dtensor._utils import compute_local_shape_and_global_offset
_rng_tracker: Optional["RNGStateTracker"] = None
@@ -318,8 +320,6 @@ def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
"""
dtensor_shape = spec.shape
- from vescale.dtensor.ops.utils import prod
-
numel = prod(dtensor_shape)
# pytorch: offset must be multiple of 4
# source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -390,6 +390,14 @@ class ThreadBasedRNGTracker(OffsetBasedRNGTracker):
def __init__(self, device_type: str = "cuda"):
super().__init__(device_type)
+ # source: aten/src/ATen/native/cuda/DistributionTemplates.h
+ self.block_size = 256
+ self.unroll = 4
+ props = torch.cuda.get_device_properties(torch.cuda.current_device())
+ # For example, in an A100: props.max_threads_per_multi_processor = 2048, props.multi_processor_count = 108
+ self.max_threads_per_multi_processor = props.max_threads_per_multi_processor
+ self.blocks_per_sm = self.max_threads_per_multi_processor // self.block_size
+ self.max_grid = props.multi_processor_count * self.blocks_per_sm
def get_offset(self, name: str) -> int:
if name not in self.rng_states:
@@ -423,29 +431,17 @@ def get_sharding_spec(self, name: str) -> Tuple[Tuple[int, ...], Tuple[int, ...]
)
def set_sharding_spec(
- self, name: str, sharding_spec: Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]
+ self,
+ name: str,
+ sharding_spec: Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]],
+ offset: int,
) -> None:
if name not in self.rng_states:
raise RuntimeError(f"{self.__class__.__name__} does not have random state for {name}")
- local_shape, global_offset, global_shape, global_strides = sharding_spec
-
seed_tensor = (self.rng_states[name])[0:8]
- offset_tensor = (self.rng_states[name])[8:16]
- local_shape_tensor = torch.tensor(local_shape).view(torch.uint8)
- global_offset_tensor = torch.tensor(global_offset).view(torch.uint8)
- global_shape_tensor = torch.tensor(global_shape).view(torch.uint8)
- global_strides_tensor = torch.tensor(global_strides).view(torch.uint8)
- self.rng_states[name] = torch.cat(
- [
- seed_tensor,
- offset_tensor,
- local_shape_tensor,
- global_offset_tensor,
- global_shape_tensor,
- global_strides_tensor,
- ]
- )
+ spec_tensor = torch.tensor(sum(sharding_spec, start=(offset,))).view(torch.uint8)
+ self.rng_states[name] = torch.cat([seed_tensor, spec_tensor])
@contextlib.contextmanager
def _distribute_region(self, spec: DTensorSpec):
@@ -457,7 +453,7 @@ def _distribute_region(self, spec: DTensorSpec):
)
if self.distribute_region_enabled:
old_offset = self.get_offset("parallel-rng")
- self._set_pre_op_sharding_spec(spec)
+ self._set_pre_op_sharding_spec(spec, old_offset)
with torch.random.fork_rng(self._devices, device_type=self._device_type):
self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
try:
@@ -468,7 +464,7 @@ def _distribute_region(self, spec: DTensorSpec):
else:
yield
- def _set_pre_op_sharding_spec(self, spec: DTensorSpec) -> None:
+ def _set_pre_op_sharding_spec(self, spec: DTensorSpec, old_offset: int) -> None:
"""Passing the DTensor sharding info via Cuda RNG State. Later on,
each GPU thread can use the info to deduce the correct thread id and
offset when generating an entry of a DTensor.
@@ -483,19 +479,22 @@ def _set_pre_op_sharding_spec(self, spec: DTensorSpec) -> None:
.. warning::
Note that, current implementation does not consider DTensor's continguity.
"""
- global_shape = spec.shape
- mesh = spec.mesh
-
- from vescale.dtensor._utils import compute_local_shape_and_global_offset
+ if spec.num_shards > 0:
+ global_shape = spec.shape
+ mesh = spec.mesh
- local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, mesh, spec.placements)
- global_strides = spec.tensor_meta.stride
+ local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, mesh, spec.placements)
+ global_strides = spec.tensor_meta.stride
- if (local_shape, global_offset) == ((), ()): # a out-of-mesh rank
- local_shape = tuple([0] * len(global_shape))
- global_offset = tuple([0] * len(global_shape))
+ if (local_shape, global_offset) == ((), ()): # a out-of-mesh rank
+ local_shape = tuple([0] * len(global_shape))
+ global_offset = tuple([0] * len(global_shape))
- self.set_sharding_spec("parallel-rng", (local_shape, global_offset, global_shape, global_strides))
+ self.set_sharding_spec(
+ "parallel-rng",
+ (local_shape, global_offset, global_shape, global_strides),
+ old_offset,
+ )
def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
"""Set the RNG state as the DTensor operation is executed on a single GPU. This
@@ -511,20 +510,12 @@ def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
"""
dtensor_shape = spec.shape
- from vescale.dtensor.ops.utils import prod
-
numel = prod(dtensor_shape)
- # source: aten/src/ATen/native/cuda/DistributionTemplates.h
- block_size = 256
- unroll = 4
- props = torch.cuda.get_device_properties(spec.mesh.device_type)
- # For example, in an A100: props.max_threads_per_multi_processor = 2048, props.multi_processor_count = 108
- blocks_per_sm = props.max_threads_per_multi_processor // block_size
- grid_x = min(props.multi_processor_count * blocks_per_sm, (numel + block_size - 1) // block_size)
- offset_incr = ((numel - 1) // (block_size * grid_x * unroll) + 1) * unroll
- self.set_offset("parallel-rng", old_offset + offset_incr)
- self.set_sharding_spec("parallel-rng", ((), (), (), ()))
+ grid_x = min(self.max_grid, (numel + self.block_size - 1) // self.block_size)
+ offset_incr = ((numel - 1) // (self.block_size * grid_x * self.unroll) + 1) * self.unroll
+ new_offset = old_offset + offset_incr
+ self.set_sharding_spec("parallel-rng", ((), (), (), ()), new_offset)
class TensorParallelRNGTracker(RNGStateTracker):
diff --git a/vescale/dtensor/redistribute.py b/vescale/dtensor/redistribute.py
index 6db1e34..7f22ba1 100644
--- a/vescale/dtensor/redistribute.py
+++ b/vescale/dtensor/redistribute.py
@@ -449,7 +449,19 @@ def forward( # type: ignore[override]
# Early return the original DTensor if the placements are the same.
if input._spec.placements == placements:
- return input
+ # FIXME: To avoid view(). There are several hidden dangers here:
+ # - The change of the tensor wrapper may cause the failure of the tensor's hooks.
+ # - Modifying the tensor may change the result of is_param of parameters.
+ # - Dynamically modifying the computation graph may cause problems with autograd.
+ return dtensor.DTensor(
+ input._local_tensor,
+ device_mesh,
+ placements,
+ shape=input.shape,
+ dtype=input.dtype,
+ requires_grad=input.requires_grad,
+ stride=input.stride(),
+ )
target_spec = DTensorSpec(device_mesh, placements, tensor_meta=input._spec.tensor_meta)
diff --git a/vescale/initialize/deferred_init.py b/vescale/initialize/deferred_init.py
index 1dd4c2e..cb34f4e 100644
--- a/vescale/initialize/deferred_init.py
+++ b/vescale/initialize/deferred_init.py
@@ -189,14 +189,14 @@ def materialize_dparameter(
torch_device = torch.device(device)
# materialize local tensor
if _C.is_gen_by_random_op(param):
- tensor_meta = TensorMeta(global_shape, (0,), param.data.dtype)
+ tensor_meta = TensorMeta(global_shape, torch_stride, param.data.dtype)
spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta)
assert random.is_rng_supported_mesh(
device_mesh
), "currently, random DTensor only support cuda/cuda=like device!"
if not random._rng_tracker:
- random._rng_tracker = random.OffsetBasedRNGTracker()
+ random._rng_tracker = random.init_vescale_rng_tracker()
assert random._rng_tracker is not None
with random._rng_tracker._distribute_region(spec):
param = _C.materialize_tensor_with_local_shape(param, local_shape, torch_device)
diff --git a/vescale/optim/README.md b/vescale/optim/README.md
index 5858d33..5582bc1 100644
--- a/vescale/optim/README.md
+++ b/vescale/optim/README.md
@@ -1,55 +1,51 @@
-# veScale Optimizers
+# veScale Optimizer Parallel
-## Overview
-
-In distributed training, optimizers also need to be adjusted accordingly. We provide two options:
+## TLDR
-### `BasicOptimizer`
+
-A simple optimizer warpper plus some utilities for distributed training, such as recover flattened gradient from `DDP` and trigger gradient all-reduce for LayerNorm (or some other similar) blocks in Sequence Parallel. `BasicOptimizer` is not a ZeRO optimizer.
-
-### `DistributedOptimizer`
+## Overview
-A "ZeRO 2+" optimizer. Simliar to `DDP`, veScale `DistributedOptimizer` is primarily inherited from [Megatron-LM's DistributedOptimizer](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/optimizer/distrib_optimizer.py). We extend compatibility of its implementation with our DTensor.
+In veScale, we provide two optimizers for _Optimizer Parallel_:
-## Implementation
+- `DistributedOptimizer`
-### `BasicOptimizer`
+- `BasicOptimizer`
-`BasicOptimizer`'s implementation is quite simple. See the docstring of `BasicOptimizer` at `/vescale/optim/base_optimizer.py`.
+### veScale `DistributedOptimizer`
-### `DistributedOptimizer`
+#### What is it?
-`DistributedOptimizer`'s implementation is complex. Different from `DDP`, in `DistributedOptimizer`, the model parameters and gradients are further split. Each DP rank only obtains the corresponding gradient, updates the corresponding parameters, maintaining the corresponding optimizer states. Therefore, a typical optimizer initialization and step process of `DistributedOptimizer` includes the following stages:
+`DistributedOptimizer` is a _ZeRO 2+_ optimizer. Similar to the original _ZeRO2_, it parallelizes model gradient and optimizer states along _Data Parallel_ dimension. Differently, it further parallelizes model parameters virtually but not physically.
-1. At initialzation, model parameters need to be split across all DP ranks, but this is not a `real` split. Each DP rank actually owns a partial view of the original model parameters. Note that this split does not respect parameter boundaries, which means that a parameter could be split into two halves and belong to two DP ranks. Therefore, a complex mapping between the dp-sharded parameters and the original parameters needs to be established, which is mostly done in the init function. At last, we replace the optimizer's param_groups with the dp-sharded parameter.
+`DistributedOptimizer` is primarily inherited from [Megatron-LM's DistributedOptimizer](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/optimizer/distrib_optimizer.py) for its performance and mostly due to the lacking of _ZeRO2_ optimizer in native PyTorch. We extend and enhance `DistributedOptimizer` with extra features:
-2. At step, copy `main_grad` attached at original parameter by `DDP` to the dp-sharded parameters.
+- convert between `Tensor` and `DTensor`
-3. Run `optimizer.step()`.
+- support online resharding of optimzier state
-4. Copy updated dp-sharded parameters to a specific param buffer. To avoid the overhead of sending each parameter individually through an allgather operation, we reused the gradient buffer's space as a parameter buffer. This allow us to store the updated parameters temporarily before they are all-gathered back to their original form before the next forward execution. This strategy helped us save GPU memory. And this introduce a further optimization.
+#### How does it work?
- - We further overlap the param all-gather with the forward, which means we trigger next part's param all-gather when we are doing the current part's forward.
+In `DistributedOptimizer`, the model gradients and optimizer states are sharded along _Data Parallel_ dimension in each gradient _Bucket_ of _Gradient Buffer_ (see `DDP` for more details), where each DP rank only manages its own shard of gradient, generates its own shard of optimizer states, and updates its own shard of parameters.
-## Compatibility of Optimizer with `DDP`
+The flow of `DistributedOptimizer` is as follows:
-The compatibility of these two optimizers and `DDP` strategy is shown as follows:
+0. During initialization, model parameters are virtually sharded across all DP ranks, such that each DP rank owns a partial view of the original model parameters
+ - This sharding does not respect parameter boundaries, i.e., a parameter could be split into two halves and belong to two DP ranks. Therefore, a complex mapping between the sharded parameters and the original parameters is established, which is mostly done in the `__init__` function. Then the optimizer's `param_groups` is replaced with the _Sharded Parameter_.
-| | `BasicOptimizer` | `DistributedOptimizer` |
-| -------- | ---------------- | ---------------------- |
-| `DDP` | yes | yes |
-| NO `DDP` | yes | no |
+1. Receive _Reduced Gradient_ resulting from `ReduceScatter` per Gradient _Bucket_ in `DDP`
-## Example
+2. Attach _Reduced Gradient_ (`main_grad` of each original parameter) to the _Sharded Parameter_
-### `BasicOptimizer`
+3. Run the actual `optimizer.step()` to generate _Optimizer State_ of each shard and updates _Sharded Parameter_ with _Reduced Gradient_
-See `/test/parallel/ddp_optim/test_ddp.py`.
+4. Copy the updated _Sharded Parameter_ to a specific parameter buffer and get ready for `AllGather` communication to restore the full parameters
-### `DistributedOptimizer`
+ - To avoid the performance overhead and memory cost of per-parameter `AllGather`, the _Gradient Buffer_ of `DDP` is reused as the communication buffer for `AllGather`.
+
+5. Overlap the parameter `AllGather` with the forward computation in the next iteration for hiding communication overhead, similar to gradient `ReduceScater` overlap with backward computation
-A simple usage case is here. For more tests, see `/test/parallel/ddp_optim/test_doptimizer.py`.
+#### How to use it?
```python
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
@@ -63,14 +59,13 @@ mlp = MLP()
# create 2-dim DeviceMesh, the first for data-parallel, while the second for tensor-parallel.
device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP"))
-# parallelize torch-native model into TP model (see: `/vescale/dmodule/README.md`)
-tp_mlp = parallelize_module(mlp, device_mesh["TP"], param_and_fwd_sharding_plan)
+# parallelize torch-native model into TP model
+tp_mlp = parallelize_module(mlp, device_mesh["TP"], sharding_plan)
-# wrap TP model with `DDP` (see: `/vescale/ddp/README.md`)
+# wrap TP model with `DDP`
dp_tp_mlp = DDP(
module=tp_mlp,
- data_pg_or_device_mesh=device_mesh["DP"],
- overlap_grad_reduce=False,
+ device_mesh["DP"],
use_distributed_optimizer=True
)
@@ -98,3 +93,34 @@ doptim.zero_grad()
#
```
+
+APIs can found in: `/vescale/optim/distributed_optimizer.py`.
+
+More examples can found in: `/test/parallel/ddp_optim/test_doptimizer.py`.
+
+### veScale `BasicOptimizer`
+
+`BasicOptimizer` is a not ZeRO optimizer but a simple optimizer that works like _Data Parallel_ which replicates parameters, gradients, and optimizer states along _Data Parallel_ dimension.
+
+`BasicOptimizer` itself is nothing but a simple wrapper that wraps given optimizer instance with utilities for veScale `DTensor`, `DModule`, and `DDP`:
+
+- convert between `Tensor` and `DTensor`
+
+- recover flattened gradient from `DDP`
+
+- trigger gradient synchronization of `DModule` (e.g., for Sequence Parallel)
+
+
+APIs can be found in: `/vescale/optim/base_optimizer.py`.
+
+Examples can be found in `/test/parallel/ddp_optim/test_ddp.py`.
+
+
+## How are these optimizers related with `DDP`?
+
+The compatibility of the above optimizers with `DDP` is as follows:
+
+| | `BasicOptimizer` | `DistributedOptimizer` |
+| -------- | ---------------- | ---------------------- |
+| `DDP` | yes | yes |
+| NO `DDP` | yes | no |