Open source 0725 patch (#42)
MackZackA authored Jul 31, 2024
1 parent aa95bb7 commit 70db7e7
Showing 29 changed files with 74 additions and 1,620 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -20,6 +20,8 @@ _**An Industrial-Level Framework for Easy-of-Use**_

## Latest News

- [2024-7-25] veScale's [pipeline parallelism](https://github.com/volcengine/veScale/blob/main/vescale/pipe/README.md) open sourced with API, graph parser, stage abstraction, schedules and execution runtime along with [nD distributed timeline](https://github.com/volcengine/veScale/blob/main/vescale/ndtimeline/README.md).

- [2024-5-31] veScale's [fast checkpointing system](https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/README.md) open sourced with automatic checkpoint resharding, caching, load-balancing, fast copying, deduplicating, and asynchronous io.

- [2024-5-21] veScale's examples ([Mixtral](https://github.com/volcengine/veScale/tree/main/examples/mixtral_4D_training), [LLama2](https://github.com/volcengine/veScale/tree/main/examples/llama2_4D_finetune), and [nanoGPT](https://github.com/volcengine/veScale/tree/main/examples/nanogpt_4D_finetune)) open sourced with bit-wise correctness of training loss curves.
@@ -32,7 +34,11 @@ _**An Industrial-Level Framework for Easy-of-Use**_

_**veScale**_ is still in its early phase. We are refactoring our internal LLM training system components to meet open source standard. The tentative timeline is as follows:

- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training
- High-level [nD parallel api](https://github.com/volcengine/veScale/issues/39) for extreme ease of use

- Power-user plan api for easy customization of nD parallel training

- End-to-end vescale/examples with 5D parallel training (TP, SP, DP, ZeRO, PP)

## Table of Content ([web view](https://volcengine.github.io/veScaleWeb/))

@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/llama_mfu_calculator.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/sharding_plan.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_dp_reshard.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_load_save.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_tp_reshard.py
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
26 changes: 12 additions & 14 deletions test/parallel/pipeline/api/test_pipe_single_stage_ops.py
@@ -21,7 +21,7 @@
from torch.testing._internal.common_utils import run_tests
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.plan import PipelineScheduleType, PipelineParallelPlan, ModeType, PipelineSplitMethodType
from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.engine import PipeEngine
from common_dtensor import DTensorTestBase, with_comms
from torch.optim import SGD
@@ -132,9 +132,7 @@ def test_stage_forward(self):
def _run_no_pp_model(self):
os.environ["model_name"] = "golden"
model = EightMLP().to("cuda:0")
optimizer = torch.optim.SGD(
model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False
)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
torch.manual_seed(9999)
batch = [torch.ones(microbatch_size, 128, 32, dtype=torch.float32).to("cuda:0") for _ in range(factor)]
for mb in batch:
@@ -166,13 +164,6 @@ def _run_stage_forward(self):
mesh_dim_names=("PP", "DP", "TP"),
)

stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
model,
config,
VESCALE_DEVICE_MESH,
update_split_points=True,
)

optimizer_fn_kwargs = {
"lr": 0.01,
"momentum": 0,
@@ -183,9 +174,16 @@ def _run_stage_forward(self):
"foreach": None,
"differentiable": False,
}
_parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
optimizer = SGD(_parameters, **optimizer_fn_kwargs)
pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)

pipe_module = construct_pipeline_stage(
model,
config,
VESCALE_DEVICE_MESH,
lr_scheduler=None,
update_split_points=True,
)
optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
pipe_module.doptimizer = optimizer

engine = PipeEngine(
pipe_module,
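The change above (and the matching one in `test_schedule_engine.py` below) collapses the old three-step setup of `construct_stage_modules`, hand-gathered parameter lists, and a manually built `PipeModule` into a single `construct_pipeline_stage` call. Below is a minimal sketch of the new pattern, not taken from the commit: the toy model, learning rate, and loss are illustrative, the `PipelineParallelPlan` keyword is assumed to match the attribute it sets, and the global device mesh is assumed to be initialized with `("PP", "DP", "TP")` dimensions as in the test.

```python
# Illustrative sketch of the post-patch stage construction; not from the commit.
import torch.nn as nn
from torch.optim import SGD

from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.engine import PipeEngine
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.plan import PipelineParallelPlan, PipelineScheduleType

# Assumes VESCALE_DEVICE_MESH has already been initialized with
# ("PP", "DP", "TP") mesh dimensions, as the test above does.
model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32)).cuda()

# schedule_type appears as a constructor keyword in the tests; other plan
# fields are left at their defaults here.
plan = PipelineParallelPlan(schedule_type=PipelineScheduleType.SIMPLE_1F1B)

# One call now returns a ready stage module: splitting, inter-stage
# dependencies, and the p2p index mapping are handled internally.
pipe_module = construct_pipeline_stage(
    model,
    plan,
    VESCALE_DEVICE_MESH,
    lr_scheduler=None,
    update_split_points=True,
)

# The optimizer is built afterwards over the stage's own parameters and
# attached to the module, exactly as the updated tests do.
pipe_module.doptimizer = SGD(pipe_module.parameters(), lr=0.01)

engine = PipeEngine(pipe_module, VESCALE_DEVICE_MESH, nn.MSELoss(), plan)  # loss is illustrative
```

The old path concatenated per-stage parameter lists by hand before constructing `SGD`; `pipe_module.parameters()` now stands in for that.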
21 changes: 10 additions & 11 deletions test/parallel/pipeline/api/test_schedule_engine.py
@@ -20,7 +20,7 @@
from common_dtensor import DTensorTestBase, with_comms
from torch.testing._internal.common_utils import run_tests
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.pipe._schedules.instruction_base import StageDeps
from vescale.pipe.pipe_emmiter import ScheduleEngine
from vescale.plan.spec import PipelineScheduleType, ModeType, PipelineSplitMethodType
@@ -79,13 +79,6 @@ def test_simple_1f1b(self):
schedule_type=PipelineScheduleType.SIMPLE_1F1B,
)

stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
model,
config,
VESCALE_DEVICE_MESH,
update_split_points=True,
)

optimizer_fn_kwargs = {
"lr": 0.01,
"momentum": 0,
@@ -96,9 +89,15 @@
"foreach": None,
"differentiable": False,
}
_parameters = list(stage_modules[0].parameters())
optimizer = SGD(_parameters, **optimizer_fn_kwargs)
pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)
pipe_module = construct_pipeline_stage(
model,
config,
VESCALE_DEVICE_MESH,
lr_scheduler=None,
update_split_points=True,
)
optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
pipe_module.doptimizer = optimizer

dep = pipe_module.stage_deps
device_mesh_list = VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes()
1 change: 0 additions & 1 deletion test/parallel/pipeline/api/test_simple_api.py
@@ -20,7 +20,6 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from vescale.debug.pdb import ForkedPdb
from vescale.optim.base_optimizer import BasicOptimizer
from vescale.pipe.pipe_stage import construct_pipeline_stage
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
2 changes: 1 addition & 1 deletion test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py
@@ -225,7 +225,7 @@ def _run_engine_with_1f1b(self, fixed_size=True):
pipe_config,
)

engine.forward_backward(batch)
engine(batch)
optimizer = engine.get_optimizer
optimizer.step()

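Here the engine is invoked directly instead of through `forward_backward`. A short sketch of the resulting training step follows; it assumes calling the engine dispatches to the same forward/backward pass, since the diff only shows the call, the optimizer retrieval, and `step()`.

```python
# Illustrative sketch of the updated step; `batch` is the list of microbatches
# built earlier in the test. Any return value of engine(batch) is ignored here
# because the diff does not show one being used.
engine(batch)                     # was: engine.forward_backward(batch)
optimizer = engine.get_optimizer  # property-style access, as in the diff
optimizer.step()
```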
24 changes: 12 additions & 12 deletions vescale/engine/pipe.py
@@ -36,7 +36,7 @@ def __init__(
module: PipeModule,
global_mesh: VeDeviceMesh,
loss_fn: Callable,
config: PipelineParallelPlan,
plan: PipelineParallelPlan,
):
"""
Training engine for pipeline parallelism and multi-dimensional
@@ -46,8 +46,8 @@ def __init__(
training, and optimizer synchronization.
"""
self.module = module
self.virtual_chunks_per_stage = config.virtual_chunks
self.engine_config = config
self.virtual_chunks_per_stage = plan.virtual_chunks
self.engine_plan = plan
self.optimizer = self.module.get_optimizer
self.lr_scheduler = self.module.get_lr_scheduler
self.global_mesh = global_mesh
@@ -59,16 +59,16 @@ def __init__(
except: # noqa: E722
self.loss_fn = loss_fn
self.schedule_engine = None
self.reuse_comm_shape = self.engine_config.reuse_p2p_tensor_shape
self.reuse_comm_shape = self.engine_plan.reuse_p2p_tensor_shape
if self.reuse_comm_shape:
os.environ["REUSE_COMM_SHAPE"] = "1"
if (
self.engine_config.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
self.engine_plan.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
and self.virtual_chunks_per_stage == 1
):
print("[warning]: #virtual pipeline chunks is 1. Falling back to simple 1F1B schedule.")
self.engine_config.schedule_type = PipelineScheduleType.SIMPLE_1F1B
self.schedule_type = self.engine_config.schedule_type
self.engine_plan.schedule_type = PipelineScheduleType.SIMPLE_1F1B
self.schedule_type = self.engine_plan.schedule_type

def build_schedule(self, minibatches, data_shape=None):
"""
@@ -105,7 +105,7 @@ def _locate_tp_mesh(_rank):
)
num_minibatches = self._align_num_batches(first_stage_rank, len(minibatches))
# TODO: insert shape inference
batch_p2p_comm = self.engine_config.batch_p2p_comm
batch_p2p_comm = self.engine_plan.batch_p2p_comm
# if on interleaved 1f1b schedule, set batch_p2p_comm to False to execute p2p communication
schedule_type = self.schedule_type
if schedule_type in [PipelineScheduleType.INTERLEAVED_1F1B, PipelineScheduleType.ZERO_BUBBLE]:
@@ -123,16 +123,16 @@ def _locate_tp_mesh(_rank):
data_iterator=data_iterator,
stage_id=self.global_mesh.get_pipeline_parallel_rank(),
shape=data_shape,
dtype=self.engine_config.p2p_tensor_dtype,
dtype=self.engine_plan.p2p_tensor_dtype,
num_chunks=self.virtual_chunks_per_stage,
input_shapes=None,
input_shapes_unpad=None,
# send_dtypes_map=self.module.recv_dtypes_dict,
overlap_p2p_comm=self.engine_config.overlap_p2p_comm,
overlap_p2p_comm=self.engine_plan.overlap_p2p_comm,
batch_p2p_comm=batch_p2p_comm,
loss_fn=self.loss_fn,
global_mesh=self.global_mesh,
forward_only=self.engine_config.forward_only,
forward_only=self.engine_plan.forward_only,
)

def forward_backward(
@@ -211,7 +211,7 @@ def parameters(self, including_frozen=False):
def sync_shared_params(self, group_id: int = 0, share_params=True) -> None:
"""
Synchronize gradients and weights among groups of specified units, dictated by
"partition_units" in PipelineConfig. Typically, this function is used for
"partition_units" in PipelineParallelPlan. Typically, this function is used for
synchronizing gradients and weights of embeddings layers in Transformer-based
architecture.
Args:
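The constructor argument is renamed from `config` to `plan` and stored as `engine_plan`; the same hunks also keep the fallback from the interleaved schedule to simple 1F1B when only one virtual chunk is configured. A hedged sketch of that behavior is below, reusing `pipe_module` from the earlier sketch; treating `virtual_chunks` and `schedule_type` as `PipelineParallelPlan` constructor keywords is an assumption based on the attributes this file reads.

```python
import torch.nn as nn

from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.engine import PipeEngine
from vescale.plan import PipelineParallelPlan, PipelineScheduleType

# Keyword names mirror the attributes __init__ reads (plan.schedule_type,
# plan.virtual_chunks); whether they are constructor keywords is assumed.
plan = PipelineParallelPlan(
    schedule_type=PipelineScheduleType.INTERLEAVED_1F1B,
    virtual_chunks=1,
)

# With a single virtual chunk, __init__ warns and rewrites the plan to
# SIMPLE_1F1B before any schedule is built. pipe_module comes from the
# earlier sketch; the loss is illustrative.
engine = PipeEngine(pipe_module, VESCALE_DEVICE_MESH, nn.MSELoss(), plan)
assert engine.schedule_type == PipelineScheduleType.SIMPLE_1F1B
```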
5 changes: 0 additions & 5 deletions vescale/model/base_gpt/__init__.py

This file was deleted.
