
Open source 0725 patch #42

Merged: 7 commits, Jul 31, 2024
8 changes: 7 additions & 1 deletion README.md
@@ -20,6 +20,8 @@ _**An Industrial-Level Framework for Easy-of-Use**_

## Latest News

- [2024-7-25] veScale's [pipeline parallelism](https://github.com/volcengine/veScale/blob/main/vescale/pipe/README.md) open sourced with API, graph parser, stage abstraction, schedules and execution runtime along with [nD distributed timeline](https://github.com/volcengine/veScale/blob/main/vescale/ndtimeline/README.md).

- [2024-5-31] veScale's [fast checkpointing system](https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/README.md) open sourced with automatic checkpoint resharding, caching, load-balancing, fast copying, deduplicating, and asynchronous io.

- [2024-5-21] veScale's examples ([Mixtral](https://github.com/volcengine/veScale/tree/main/examples/mixtral_4D_training), [LLama2](https://github.com/volcengine/veScale/tree/main/examples/llama2_4D_finetune), and [nanoGPT](https://github.com/volcengine/veScale/tree/main/examples/nanogpt_4D_finetune)) open sourced with bit-wise correctness of training loss curves.
@@ -32,7 +34,11 @@ _**An Industrial-Level Framework for Easy-of-Use**_

_**veScale**_ is still in its early phase. We are refactoring our internal LLM training system components to meet open source standard. The tentative timeline is as follows:

- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training
- High-level [nD parallel api](https://github.com/volcengine/veScale/issues/39) for extreme ease of use

- Power-user plan api for easy customization of nD parallel training

- End-to-end vescale/examples with 5D parallel training (TP, SP, DP, ZeRO, PP)

## Table of Content ([web view](https://volcengine.github.io/veScaleWeb/))

@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/llama_mfu_calculator.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/sharding_plan.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_dp_reshard.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_load_save.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_tp_reshard.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
26 changes: 12 additions & 14 deletions test/parallel/pipeline/api/test_pipe_single_stage_ops.py
@@ -21,7 +21,7 @@
from torch.testing._internal.common_utils import run_tests
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.plan import PipelineScheduleType, PipelineParallelPlan, ModeType, PipelineSplitMethodType
from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.engine import PipeEngine
from common_dtensor import DTensorTestBase, with_comms
from torch.optim import SGD
@@ -132,9 +132,7 @@ def test_stage_forward(self):
def _run_no_pp_model(self):
os.environ["model_name"] = "golden"
model = EightMLP().to("cuda:0")
optimizer = torch.optim.SGD(
model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False
)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
torch.manual_seed(9999)
batch = [torch.ones(microbatch_size, 128, 32, dtype=torch.float32).to("cuda:0") for _ in range(factor)]
for mb in batch:
@@ -166,13 +164,6 @@ def _run_stage_forward(self):
mesh_dim_names=("PP", "DP", "TP"),
)

stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
model,
config,
VESCALE_DEVICE_MESH,
update_split_points=True,
)

optimizer_fn_kwargs = {
"lr": 0.01,
"momentum": 0,
@@ -183,9 +174,16 @@
"foreach": None,
"differentiable": False,
}
_parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
optimizer = SGD(_parameters, **optimizer_fn_kwargs)
pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)

pipe_module = construct_pipeline_stage(
model,
config,
VESCALE_DEVICE_MESH,
lr_scheduler=None,
update_split_points=True,
)
optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
pipe_module.doptimizer = optimizer

engine = PipeEngine(
pipe_module,
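The net effect of the hunks above (and of the matching change in `test_schedule_engine.py` below) is that the manual `construct_stage_modules` + `PipeModule` assembly is replaced by a single `construct_pipeline_stage` call, with the optimizer built from the resulting module's parameters and attached afterwards. A minimal sketch of the new pattern, assuming the `model` (an `EightMLP`) and `config` (a `PipelineParallelPlan`) built earlier in the test; the loss function is a placeholder:

```python
from torch.optim import SGD

from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.engine import PipeEngine
from vescale.pipe.pipe_stage import construct_pipeline_stage

# One call now builds the pipeline stage module; stage splitting, inter-stage
# dependencies, and p2p index mapping are handled internally instead of being
# returned separately by construct_stage_modules.
pipe_module = construct_pipeline_stage(
    model,                 # EightMLP instance built earlier in the test
    config,                # PipelineParallelPlan built earlier in the test
    VESCALE_DEVICE_MESH,
    lr_scheduler=None,
    update_split_points=True,
)

# The optimizer is created from the stage module's own parameters and attached
# to it afterwards (the diff assigns it to `doptimizer`).
optimizer = SGD(pipe_module.parameters(), lr=0.01, momentum=0, weight_decay=0)
pipe_module.doptimizer = optimizer

# The engine is then constructed as before; the loss function here is a stand-in.
engine = PipeEngine(pipe_module, VESCALE_DEVICE_MESH, lambda out: out.mean(), config)
```

Since the stage module now carries its own dependency and p2p mapping state, the tests no longer need to thread `stage_dependency` and `p2p_index_mapping` through by hand.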
21 changes: 10 additions & 11 deletions test/parallel/pipeline/api/test_schedule_engine.py
@@ -20,7 +20,7 @@
from common_dtensor import DTensorTestBase, with_comms
from torch.testing._internal.common_utils import run_tests
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.pipe._schedules.instruction_base import StageDeps
from vescale.pipe.pipe_emmiter import ScheduleEngine
from vescale.plan.spec import PipelineScheduleType, ModeType, PipelineSplitMethodType
@@ -79,13 +79,6 @@ def test_simple_1f1b(self):
schedule_type=PipelineScheduleType.SIMPLE_1F1B,
)

stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
model,
config,
VESCALE_DEVICE_MESH,
update_split_points=True,
)

optimizer_fn_kwargs = {
"lr": 0.01,
"momentum": 0,
@@ -96,9 +89,15 @@
"foreach": None,
"differentiable": False,
}
_parameters = list(stage_modules[0].parameters())
optimizer = SGD(_parameters, **optimizer_fn_kwargs)
pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)
pipe_module = construct_pipeline_stage(
model,
config,
VESCALE_DEVICE_MESH,
lr_scheduler=None,
update_split_points=True,
)
optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
pipe_module.doptimizer = optimizer

dep = pipe_module.stage_deps
device_mesh_list = VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes()
1 change: 0 additions & 1 deletion test/parallel/pipeline/api/test_simple_api.py
@@ -20,7 +20,6 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from vescale.debug.pdb import ForkedPdb
from vescale.optim.base_optimizer import BasicOptimizer
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
2 changes: 1 addition & 1 deletion test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py
@@ -225,7 +225,7 @@ def _run_engine_with_1f1b(self, fixed_size=True):
pipe_config,
)

engine.forward_backward(batch)
engine(batch)
optimizer = engine.get_optimizer
optimizer.step()

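The one-line change above only switches the call site from the explicit `forward_backward` method to invoking the engine object directly; the rest of the training step is untouched. A short sketch of the updated step, reusing the `engine` and `batch` objects this test already builds:

```python
# Sketch of the updated call site: the engine object is invoked directly
# (previously engine.forward_backward(batch)); the optimizer handling is unchanged.
engine(batch)
optimizer = engine.get_optimizer  # property access, as in the surrounding test code
optimizer.step()
```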
24 changes: 12 additions & 12 deletions vescale/engine/pipe.py
@@ -36,7 +36,7 @@ def __init__(
module: PipeModule,
global_mesh: VeDeviceMesh,
loss_fn: Callable,
config: PipelineParallelPlan,
plan: PipelineParallelPlan,
):
"""
Training engine for pipeline parallelism and multi-dimensional
@@ -46,8 +46,8 @@
training, and optimizer synchronization.
"""
self.module = module
self.virtual_chunks_per_stage = config.virtual_chunks
self.engine_config = config
self.virtual_chunks_per_stage = plan.virtual_chunks
self.engine_plan = plan
self.optimizer = self.module.get_optimizer
self.lr_scheduler = self.module.get_lr_scheduler
self.global_mesh = global_mesh
@@ -59,16 +59,16 @@
except: # noqa: E722
self.loss_fn = loss_fn
self.schedule_engine = None
self.reuse_comm_shape = self.engine_config.reuse_p2p_tensor_shape
self.reuse_comm_shape = self.engine_plan.reuse_p2p_tensor_shape
if self.reuse_comm_shape:
os.environ["REUSE_COMM_SHAPE"] = "1"
if (
self.engine_config.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
self.engine_plan.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
and self.virtual_chunks_per_stage == 1
):
print("[warning]: #virtual pipeline chunks is 1. Falling back to simple 1F1B schedule.")
self.engine_config.schedule_type = PipelineScheduleType.SIMPLE_1F1B
self.schedule_type = self.engine_config.schedule_type
self.engine_plan.schedule_type = PipelineScheduleType.SIMPLE_1F1B
self.schedule_type = self.engine_plan.schedule_type

def build_schedule(self, minibatches, data_shape=None):
"""
@@ -105,7 +105,7 @@ def _locate_tp_mesh(_rank):
)
num_minibatches = self._align_num_batches(first_stage_rank, len(minibatches))
# TODO: insert shape inference
batch_p2p_comm = self.engine_config.batch_p2p_comm
batch_p2p_comm = self.engine_plan.batch_p2p_comm
# if on interleaved 1f1b schedule, set batch_p2p_comm to False to execute p2p communication
schedule_type = self.schedule_type
if schedule_type in [PipelineScheduleType.INTERLEAVED_1F1B, PipelineScheduleType.ZERO_BUBBLE]:
@@ -123,16 +123,16 @@
data_iterator=data_iterator,
stage_id=self.global_mesh.get_pipeline_parallel_rank(),
shape=data_shape,
dtype=self.engine_config.p2p_tensor_dtype,
dtype=self.engine_plan.p2p_tensor_dtype,
num_chunks=self.virtual_chunks_per_stage,
input_shapes=None,
input_shapes_unpad=None,
# send_dtypes_map=self.module.recv_dtypes_dict,
overlap_p2p_comm=self.engine_config.overlap_p2p_comm,
overlap_p2p_comm=self.engine_plan.overlap_p2p_comm,
batch_p2p_comm=batch_p2p_comm,
loss_fn=self.loss_fn,
global_mesh=self.global_mesh,
forward_only=self.engine_config.forward_only,
forward_only=self.engine_plan.forward_only,
)

def forward_backward(
@@ -211,7 +211,7 @@ def parameters(self, including_frozen=False):
def sync_shared_params(self, group_id: int = 0, share_params=True) -> None:
"""
Synchronize gradients and weights among groups of specified units, dictated by
"partition_units" in PipelineConfig. Typically, this function is used for
"partition_units" in PipelineParallelPlan. Typically, this function is used for
synchronizing gradients and weights of embeddings layers in Transformer-based
architecture.
Args:
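Taken together, the edits in this file rename the constructor argument `config` to `plan` and the attribute `engine_config` to `engine_plan`, so the engine's vocabulary matches `PipelineParallelPlan`. A minimal sketch of engine construction against the renamed signature; the keyword-style plan construction and field choices are assumptions for illustration, and `pipe_module` and `loss_fn` are taken to exist already (e.g. built as in the updated tests above):

```python
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.engine import PipeEngine
from vescale.plan import PipelineParallelPlan, PipelineScheduleType

# Assumed keyword-style construction; schedule_type and virtual_chunks are the
# only plan fields actually referenced in this diff.
plan = PipelineParallelPlan(
    schedule_type=PipelineScheduleType.SIMPLE_1F1B,
    virtual_chunks=1,
)

engine = PipeEngine(
    pipe_module,          # PipeModule returned by construct_pipeline_stage
    VESCALE_DEVICE_MESH,  # global device mesh
    loss_fn,              # callable producing a scalar loss
    plan,                 # fourth positional argument, formerly named `config`
)
```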
5 changes: 0 additions & 5 deletions vescale/model/base_gpt/__init__.py

This file was deleted.
