Skip to content

Commit

Permalink
[checkpoint] feat: open source fast checkpoint system (#38)
Browse files Browse the repository at this point in the history
## Summary

We improved `vescale.checkpoint` with the following new features for
fast checkpointing (where front three features are built-in techniques
without necessitating manual activation):

- **Saving Plan Caching**: During training, the program may save model
and optimizer checkpoints every n steps. Once a saving plan is created,
it remains unchanged as long as the model does. We implemented plan
caching to avoid regenerating the plan when checkpointing a model or
optimizer multiple times, reducing unnecessary compute and communication
costs. As of 05/30/2024, PyTorch DCP does not support plan caching.

- **Saving Plan Load-Balancing**: In data parallel training, models are
replicated across GPUs with different data parallel ranks but the same
pipeline and tensor parallel ranks. Existing PyTorch DCP (as of
05/30/2024) deduplicates replicated tensors using a simple algorithm,
causing GPUs with data parallel rank 0 to save the entire model, leading
to load imbalance. We implemented a load-balancing algorithm to address
this issue when deduplicating model tensors.

- **D2H Tensor Copying via Pinned Memory**: When copying tensors from
GPU to host memory, `vescale.checkpoint` uses pinned host memory,
reducing memory allocation costs each time a checkpoint is saved. As of
05/30/2024, PyTorch DCP does not support pinned memory.

- **Checkpoint Broadcasting**: In data parallel training, models are
replicated across GPUs with different data parallel ranks but the same
pipeline and tensor parallel ranks. If `broadcast_checkpoint` is
enabled, `vescale.checkpoint.load` lets GPUs with data parallel rank 0
to load the model and broadcast it to other GPUs with higher data
parallel ranks. If GPUs are connected with NCCL and I/O bandwidth is
fully utilized, broadcasting model tensors speeds up checkpoint loading
compared to all GPUs loading models from persistent storage. E.g.:

    ```python
    # prepare checkpoint state for the model and optimizer
checkpoint_state = { "model": distributed_model, "optimizer":
distributed_optimizer }
    # load the checkpoint
vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state,
broadcast_checkpoint=True)
    ```

- **Asynchronous Checkpointing**: When `vescale.checkpoint.save` is
called, it first generates a saving plan and then synchronously copies
tensors from GPU to host memory. If `async_checkpoint` is enabled, the
training program can continue after the D2H copying, while
`vescale.checkpoint.save` continues to serialize tensors and dump the
checkpoint to persistent storage asynchronously without blocking
training. As of 05/30/2024, PyTorch DCP does not support asynchronous
checkpointing. E.g.:

    ```python
    # prepare checkpoint state for the model and optimizer
checkpoint_state = { "model": distributed_model, "optimizer":
distributed_optimizer }
    # save the checkpoint asynchronuously
vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state,
async_checkpoint=True)
    ```

## Acknowledgement

We sincerely appreciate all contributors including but not limited to
@shanesyy-1992 @raywan-110 @lazychao @AHEADer @MingjiHan99
  • Loading branch information
MingjiHan99 authored May 31, 2024
1 parent 55c7f8a commit c4afc72
Show file tree
Hide file tree
Showing 67 changed files with 2,977 additions and 1,044 deletions.
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ _**An Industrial-Level Framework for Easy-of-Use**_

- 📀 **Automatic Checkpoint Resharding**: veScale manages distributed checkpoints automatically with online resharding across different cluster sizes and different parallelism strategies.

## Latest News

## Coming Soon
- [2024-5-31] veScale's [fast checkpointing system](https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/README.md) open sourced with automatic checkpoint resharding, caching, load-balancing, fast copying, deduplicating, and asynchronous io.

- [2024-5-21] veScale's examples ([Mixtral](https://github.com/volcengine/veScale/tree/main/examples/mixtral_4D_training), [LLama2](https://github.com/volcengine/veScale/tree/main/examples/llama2_4D_finetune), and [nanoGPT](https://github.com/volcengine/veScale/tree/main/examples/nanogpt_4D_finetune)) open sourced with bit-wise correctness of training loss curves.

- [2024-5-13] The debut of veScale in MLSys 2024 as a [poster](https://volcengine.github.io/veScaleWeb/blog/mlsys2024.html).

_**veScale**_ is still in its early phase. We are refactoring our [internal LLM training system](https://arxiv.org/abs/2402.15627) components to meet open source standard. The tentative timeline is as follows:
- [2024-4-16] Our [internal LLM training system](https://volcengine.github.io/veScaleWeb/blog/megascale.html) presented in NSDI 2024.

## Coming Soon

- by end of May, fast checkpointing system
_**veScale**_ is still in its early phase. We are refactoring our internal LLM training system components to meet open source standard. The tentative timeline is as follows:

- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training

Expand Down
3 changes: 1 addition & 2 deletions examples/llama2_4D_finetune/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
################################################################################

import os
import re


def parse_train_loss(log_fn, name=None):
Expand Down Expand Up @@ -57,7 +56,7 @@ def parse(log_fn, name=None):

def run_exps(max_iters, dtypes, run=True):
if not os.path.isfile(TRAIN_BIN_PATH):
os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..")
os.system("cd data/shakespeare/ && python3 prepare.py && cd ../..")
os.makedirs("logs", exist_ok=True)
if run:
for dtype in dtypes:
Expand Down
2 changes: 1 addition & 1 deletion examples/mixtral_4D_training/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse_grad_norm(log_fn, name=None):

def run_exps(max_iters, dtypes, run=True):
if not os.path.isfile(TRAIN_BIN_PATH):
os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..")
os.system("cd data/shakespeare/ && python3 prepare.py && cd ../..")
os.makedirs("logs", exist_ok=True)
if run:
for dtype in dtypes:
Expand Down
10 changes: 8 additions & 2 deletions examples/nanogpt_4D_finetune/finetune_4D.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
save_checkpoint_path = "./nanogpt_checkpoint_dir"
load_checkpoint_path = ""
use_dist_dropout = True
async_checkpoint = False
broadcast_checkpoint = False
config = {}


Expand Down Expand Up @@ -349,7 +351,7 @@ def get_lr(it):
# + + + VeScale Load checkpoint
if load_checkpoint_path:
checkpoint_state = {"model": model, "optimizer": optimizer}
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint)
# + + + VeScale API above
# training loop
X, Y = get_batch("train") # fetch the very first batch
Expand Down Expand Up @@ -384,7 +386,11 @@ def get_lr(it):
# Don't save checkpoint
# + + + VeScale API below
checkpoint_state = {"model": model, "optimizer": optimizer}
vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
vescale.checkpoint.save(
os.path.join(save_checkpoint_path, f"iter_{iter_num}"),
checkpoint_state,
async_checkpoint=async_checkpoint,
)
# + + + VeScale API above
if iter_num == 0 and eval_only:
break
Expand Down
2 changes: 1 addition & 1 deletion examples/nanogpt_4D_finetune/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def from_pretrained(cls, model_type, override_args=None):
assert all(k == "dropout" for k in override_args)
from transformers import GPT2LMHeadModel

print("loading weights from pretrained gpt: %s" % model_type)
print(f"loading weights from pretrained gpt: {model_type}")

# n_layer, n_head and n_embd are determined from model_type
# + + + add a gpt2-small option for smaller experiments
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ optree
accelerate
transformers==4.37.2
flash_attn
mmh3
2 changes: 1 addition & 1 deletion test/checkpoint/nano_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def from_pretrained(cls, model_type, override_args=None):
assert all(k == "dropout" for k in override_args)
from transformers import GPT2LMHeadModel

print("loading weights from pretrained gpt: %s" % model_type)
print(f"loading weights from pretrained gpt: {model_type}")

# n_layer, n_head and n_embd are determined from model_type
config_args = {
Expand Down
6 changes: 2 additions & 4 deletions test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ def test_save(self):
dist_optimizer.step()

# Save the model and optimizer before second data foward

# OmniStore Style API
ckpt_state = {"model": ddp_gpt, "optimizer": dist_optimizer}
vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state)

# Clean up writing futures (For unit test only)
vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup()
# Dump model state_dict
dumped_model_sd = {}
for k, v in ddp_gpt.state_dict().items():
Expand Down Expand Up @@ -108,7 +107,6 @@ def test_load(self):

# Load the model and optimizer after first data

# OmniStore Style API
# One line function, model and optimizer will be loaded automatically
ckpt_state = {"model": ddp_gpt, "optimizer": dist_optimizer}
vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state)
Expand Down
2 changes: 2 additions & 0 deletions test/checkpoint/open_llama/test_open_llama_dp_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def test_open_llama2_with_ddp(self):

ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer}
vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state)
# Clean up writing futures (For unit test only)
vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup()
# For processes with dp_rank = 0, dump model state_dict
if VESCALE_DEVICE_MESH.get_data_parallel_rank() == 0:
dumped_model_sd = {}
Expand Down
2 changes: 2 additions & 0 deletions test/checkpoint/open_llama/test_open_llama_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_open_llama2_with_ddp(self):

ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer}
vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state)
# Clean up writing futures (For unit test only)
vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup()

# Dump model state_dict
dumped_model_sd = {}
Expand Down
2 changes: 2 additions & 0 deletions test/checkpoint/open_llama/test_open_llama_tp_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def test_open_llama2_with_ddp(self):

ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer}
vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state)
# Clean up writing futures (For unit test only)
vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup()

# Merge model state dictionary and save it
# full_tensor contains gather operations
Expand Down
39 changes: 39 additions & 0 deletions test/dmodule/test_fwd_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,5 +805,44 @@ def _test_dict_fwd_plan(self):
self.assert_helper(out, expected_t)


class FwdPlanTestWNestedDictArgs(FwdPlanTestBase):
class DefaultNestedDictArgs(nn.Module):
def forward(self, a: dict = None, b: torch.Tensor = None, *args):
return a["_a"], a["_b"], b

model = DefaultNestedDictArgs

def _test_nested_dict_fwd_plan(self):
fwd_plan = {".input": {"a": {"_a": [Shard(0)], "_b": [Shard(1)]}}}
dmodule = parallelize_module(self.model(), self.device_mesh, {"parameter": {}, "forward": fwd_plan})
_a, _b, b = torch.ones((2, 2)), torch.ones((2, 2)) * 2, torch.ones((2, 2)) * 3
expected_t = [Shard(0), Shard(1), torch.Tensor]

out = dmodule(a={"_a": _a, "_b": _b}, b=b)
self.assert_helper(out, expected_t)


class FwdPlanTestWNestedListArgs(FwdPlanTestBase):
class DefaultNestedListArgs(nn.Module):
def forward(self, a: list, b: torch.Tensor = None, *args):
return a[0], a[1], a[2], b

model = DefaultNestedListArgs

def _test_nested_list_fwd_plan(self):
fwd_plan = {
".input": {
"a": [[Shard(0)], None, None],
"b": [Replicate()],
}
}
dmodule = parallelize_module(self.model(), self.device_mesh, {"parameter": {}, "forward": fwd_plan})
a0, a1, a2, b = torch.ones((2, 2)), torch.ones((2, 2)) * 2, 1, torch.ones((2, 2)) * 3
expected_t = [Shard(0), torch.Tensor, int, Replicate()]

out = dmodule(a=[a0, a1, a2], b=b)
self.assert_helper(out, expected_t)


if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion test/dmodule/test_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _run_parallelize_meta_not_sharded(self, device_type):
def test_initialize_cpu(self):
self._run_parallelize_not_meta_not_sharded("cpu")
self._run_parallelize_not_meta_sharded("cpu")
self._run_parallelize_meta_not_sharded("cpu")
# self._run_parallelize_meta_not_sharded("cpu")

@with_comms_device(device_type="cuda")
def test_initialize_cuda(self):
Expand Down
2 changes: 2 additions & 0 deletions test/dmodule/test_saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Dict
import tempfile

import unittest
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests
Expand Down Expand Up @@ -113,6 +114,7 @@ def _run_load_model(self, saved_device_type, model_device_type):
self.assertTrue(dtensor.allclose(dmlp(input_tensor), dmlp_golden(input_golden)))

@with_comms_device(device_type="cpu")
@unittest.skip("fail by cuda rng")
def test_cpu(self):
self._run_save("cpu")
self._run_load_model("cpu", "cpu")
Expand Down
7 changes: 7 additions & 0 deletions test/dtensor/general/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def test_equal(self):
dtensor3 = DTensor.from_local(local_tensor3, device_mesh, [Shard(0)])
self.assertTrue(aten.equal(dtensor1, dtensor3) is False)

if self.rank % 2 == 0:
local_tensor4 = torch.ones((2, 8), dtype=torch.float32, device="cuda")
else:
local_tensor4 = torch.zeros((2, 8), dtype=torch.float32, device="cuda")
dtensor4 = DTensor.from_local(local_tensor4, device_mesh, [Shard(0)])
self.assertTrue(aten.equal(dtensor1, dtensor4) is False)

@skip_unless_torch_gpu
@with_comms
def test_local_scalar_dense(self):
Expand Down
1 change: 1 addition & 0 deletions test/dtensor/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# shut up pylint
70 changes: 70 additions & 0 deletions test/dtensor/loss/test_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
################################################################################
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
################################################################################
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
################################################################################

import itertools
from common_dtensor import (
DTensorTestBase,
with_comms,
)

import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests
from vescale import distribute_tensor
from vescale.dtensor.placement_types import Shard
from vescale.dtensor.loss import loss_parallel


class DistLossParallelTest(DTensorTestBase):
@with_comms
def test_loss_parallel(self):
device_mesh = self.build_device_mesh()

channel_size, channel_dim = 16, 1
test_setup = [
(2, (8, channel_size), (8,)), # calling aten.nll_loss_forward
(3, (8, channel_size, 12), (8, 12)), # calling aten.nll_loss2d_forward
]
weight = torch.rand(channel_size, device=self.device_type)
for input_ndim, input_size, target_size in test_setup:
x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
target = torch.randint(channel_size, target_size, device=self.device_type)

shard_dims = list(range(input_ndim))
reductions = ["none", "mean", "sum"]
for shard_dim, reduction in itertools.product(shard_dims, reductions):
dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
y = F.cross_entropy(x, target, weight, reduction=reduction)
with loss_parallel():
if shard_dim == channel_dim:
dist_y = F.cross_entropy(dist_x, target, weight, reduction=reduction)

self.assertTrue(dist_y.placements[0].is_replicate())
self.assertEqual(dist_y.to_local(), y)

if reduction == "none":
y.sum().backward()
dist_y.sum().backward()
else:
y.backward()
dist_y.backward()
self.assertTrue(dist_x.grad.placements[0].is_shard(shard_dim))
self.assertEqual(dist_x.grad.full_tensor(), x.grad)
x.grad.zero_()
else:
with self.assertRaisesRegex(
ValueError,
"loss_parallel",
):
dist_y = F.cross_entropy(dist_x, target, reduction=reduction)


if __name__ == "__main__":
run_tests()
65 changes: 65 additions & 0 deletions test/dtensor/ops/test_view_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,49 @@ def test_view_groups(self):
Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
),
)
self.assertEqual(
view_groups([2, 0], [0, 2]),
(
Split(Flatten((InputDim(0), InputDim(1))), (0, 2), 0),
Split(Flatten((InputDim(0), InputDim(1))), (0, 2), 1),
),
)
self.assertEqual(
view_groups([1, 0, 0, 1], [0, 1, 3]),
(
Split(Flatten((InputDim(1), InputDim(2))), (0, 3), 0),
Singleton(),
Split(Flatten((InputDim(1), InputDim(2))), (0, 3), 1),
),
)
self.assertEqual(
view_groups([1, 0, 2, 3], [0, 1, 0, 10]),
(
Split(Flatten((InputDim(1), InputDim(2), InputDim(3))), (0, 0, 10), 0),
Singleton(),
Split(Flatten((InputDim(1), InputDim(2), InputDim(3))), (0, 0, 10), 1),
Split(Flatten((InputDim(1), InputDim(2), InputDim(3))), (0, 0, 10), 2),
),
)
self.assertEqual(
view_groups([0, 9, 1], [1, -1]),
(
Singleton(),
Flatten((InputDim(0), InputDim(1))),
),
)
self.assertEqual(
view_groups([1, 0], [0, 0, 1, 3, 1, 0, 10]),
(
Split(InputDim(1), (0, 0, 3, 0, 10), 0),
Split(InputDim(1), (0, 0, 3, 0, 10), 1),
Singleton(),
Split(InputDim(1), (0, 0, 3, 0, 10), 2),
Singleton(),
Split(InputDim(1), (0, 0, 3, 0, 10), 3),
Split(InputDim(1), (0, 0, 3, 0, 10), 4),
),
)
self.assertEqual(
view_groups([3, 4, 5], [12, 5]),
(Flatten((InputDim(0), InputDim(1))), InputDim(2)),
Expand Down Expand Up @@ -379,6 +422,17 @@ def test_view_ops(self):
(Flatten((InputDim(0), InputDim(1))), InputDim(2)),
)

self.dimmap_test(
torch.reshape,
(randn(8, 12, 0), (8, 12, 1, 0)),
(
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 0),
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 1),
Singleton(),
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 2),
),
)

self.dimmap_test(
torch.tile,
(randn(24, 36), (1, 2, 1, 1, 2)),
Expand Down Expand Up @@ -419,6 +473,17 @@ def test_view_ops(self):
(Flatten((InputDim(0), InputDim(1))), InputDim(2)),
)

self.dimmap_test(
Tensor.view,
(randn(8, 12, 0), (8, 12, 1, 0)),
(
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 0),
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 1),
Singleton(),
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 2),
),
)

self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))

self.dimmap_test(
Expand Down
Loading

0 comments on commit c4afc72

Please sign in to comment.