Merge pull request #6065 from duanjunwen/dev/zero_bubble
[Feat] Support zero bubble with shardformer input
duanjunwen authored Sep 24, 2024
2 parents 11ae684 + 7e6f793 commit 8501202
Showing 4 changed files with 328 additions and 88 deletions.
18 changes: 15 additions & 3 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -29,6 +29,7 @@
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ def __init__(
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ def __init__(
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style == "interleaved" or pp_style == "zbv"
+            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -293,7 +297,7 @@ def __init__(
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -315,6 +319,14 @@ def __init__(
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
+            elif pp_style == "zbv":
+                self.schedule = ZeroBubbleVPipeScheduler(
+                    schedule=scheduler_nodes,
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    overlap_p2p=overlap_p2p,
+                )
             else:
                 raise NotImplementedError()

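For orientation, a minimal sketch of how the new pipeline style might be selected when constructing the plugin. Only pp_style, num_model_chunks, scheduler_nodes, num_microbatches, and overlap_p2p appear in this diff; the remaining argument names and values are illustrative assumptions, and scheduler_nodes is assumed to be a precomputed zero-bubble schedule built outside this commit.

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

scheduler_nodes = ...  # placeholder: precomputed "zbv" schedule nodes, produced elsewhere

plugin = MoeHybridParallelPlugin(
    tp_size=1,                        # illustrative parallel sizes, not part of this diff
    pp_size=4,
    ep_size=1,
    pp_style="zbv",                   # new style, routed to ZeroBubbleVPipeScheduler
    num_model_chunks=2,               # "zbv" (like "interleaved") allows more than one model chunk
    scheduler_nodes=scheduler_nodes,  # new constructor argument added in this diff
    num_microbatches=8,
    overlap_p2p=True,
)
booster = Booster(plugin=plugin)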
38 changes: 38 additions & 0 deletions colossalai/pipeline/schedule/_utils.py
@@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
         x.retain_grad()


+def require_grad(x: Any) -> None:
+    """Call requires_grad_() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+    """
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
+        x.requires_grad_()
+
+
 def detach(x: Any) -> Any:
     """Call detach() on a tensor.
@@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
     return x


+def clone(x: Any) -> Any:
+    """Call clone() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The cloned object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.clone()
+    return x
+
+
+def release_tensor_data(x: Any) -> Any:
+    """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The deallocated .data object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
 def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
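A short sketch of how the new helpers behave; the import path follows the file above, while the tensors, shapes, and the matmul example are made up. require_grad enables grad tracking on a received stage input in place, and release_tensor_data frees an activation's storage while keeping its shape metadata and grad_fn, so the autograd graph can still be walked later.

import torch
from colossalai.pipeline.schedule._utils import release_tensor_data, require_grad

# A stage input received over p2p arrives as a plain tensor; enable grad tracking in place.
x = torch.randn(16, 1024)
require_grad(x)                      # no-op if x.requires_grad is already True

w = torch.randn(1024, 1024, requires_grad=True)
y = x @ w                            # non-leaf output with a grad_fn

release_tensor_data(y)               # frees y's storage; shape metadata and grad_fn survive
print(y.untyped_storage().size())    # 0 -- the activation memory is released
print(y.grad_fn is not None)         # True -- the graph remains usable for backward

# matmul's backward only needs x and w, so releasing y's data does not break it.
torch.autograd.backward(y, torch.randn(16, 1024))
print(x.grad.shape, w.grad.shape)    # torch.Size([16, 1024]) torch.Size([1024, 1024])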