[gemini] gemini support extra-dp (hpcaitech#5043)
* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
flybird11111 authored Nov 16, 2023
1 parent b2ad0d9 commit 3e02154
Showing 10 changed files with 96 additions and 137 deletions.
31 changes: 18 additions & 13 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -10,6 +10,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
@@ -34,8 +35,7 @@
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}

DP_AXIS = 0
TP_AXIS = 1
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2

def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1.
extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
@@ -347,8 +347,8 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
enable_tensor_parallelism: bool = False,
tp_size: int = 1,
extra_dp_size:int = 1,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
@@ -393,7 +393,7 @@ def __init__(
max_norm=max_norm,
norm_type=norm_type,
)
self.enable_tensor_parallelism = enable_tensor_parallelism
self.enable_tensor_parallelism = tp_size > 1
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
@@ -402,12 +402,17 @@ def __init__(
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose

self.tp_size = tp_size if self.enable_tensor_parallelism else 1
self.dp_size = dist.get_world_size() // self.tp_size
assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.tp_size = tp_size
self.extra_dp_size = extra_dp_size
world_size = dist.get_world_size()
self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."

self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
enable_tensor_parallelism=self.enable_tensor_parallelism,
@@ -458,7 +463,7 @@ def configure(
shardformer = ShardFormer(self.shard_config)
model, _ = shardformer.optimize(model)

model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
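The net effect in the plugin is that the world size is now factored into three axes, zero x extra-dp x tp, instead of dp x tp. A minimal usage sketch (assuming an 8-GPU torchrun launch and placeholder model/optimizer objects, none of which are part of this commit) might look like:

```python
# Sketch only: how the new GeminiPlugin arguments compose on an assumed 8-GPU run.
# With tp_size=2 and extra_dp_size=2, the remaining zero_size is 8 // (2 * 2) = 2,
# so the process mesh is (zero=2, extra_dp=2, tp=2).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(
    tp_size=2,         # tensor-parallel group size, handled by Shardformer
    extra_dp_size=2,   # DDP-style replication on top of ZeRO
    initial_scale=2**5,
    max_norm=1.0,
)
booster = Booster(plugin=plugin)

# model, optimizer, criterion and dataloader are placeholders defined elsewhere:
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```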
20 changes: 15 additions & 5 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -61,12 +61,13 @@ class Chunk:
def __init__(
self,
chunk_size: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False,
extra_dp_group: ProcessGroup = None,
) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
@@ -76,7 +77,7 @@ def __init__(
Args:
chunk_size (int): the number of elements in the chunk
process_group (ProcessGroup): the process group of this chunk
zero_group (ProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
@@ -90,9 +91,11 @@ def __init__(
self.chunk_size = chunk_size
self.utilized_size = 0

self.torch_pg = process_group
self.torch_pg = zero_group
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
self.extra_dp_group = extra_dp_group
self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1

# the chunk size should be divisible by the dp degree
if not keep_gathered:
@@ -384,14 +387,20 @@ def reduce(self):
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)

free_storage(self.cuda_global_chunk)
self.is_gathered = False
@@ -608,10 +617,11 @@ def init_grad_chunk(self) -> "Chunk":
# grad chunk is not initialized
grad_chunk = Chunk(
chunk_size=self.chunk_size,
process_group=self.torch_pg,
zero_group=self.torch_pg,
dtype=self.dtype,
keep_gathered=self.keep_gathered,
pin_memory=self.pin_memory,
extra_dp_group=self.extra_dp_group,
)
grad_chunk.num_tensors = self.num_tensors
grad_chunk.utilized_size = self.utilized_size
@@ -640,4 +650,4 @@ def init_grad_chunk(self) -> "Chunk":
self.grad_chunk.l2_norm = None
alloc_storage(self.grad_chunk.cuda_global_chunk)

return self.grad_chunk
return self.grad_chunk
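The substantive change in `Chunk.reduce` is a two-level reduction: the gradient chunk is first reduce-scattered (or all-reduced, for kept-gathered chunks) within the ZeRO group, and the resulting shard is then all-reduced across the extra-dp group. The standalone sketch below illustrates that communication pattern; it is not the Chunk code itself, and `zero_group` / `extra_dp_group` are assumed to be process groups created elsewhere:

```python
# Standalone illustration of the two-level reduction now used by Chunk.reduce.
# zero_group / extra_dp_group are assumed process groups (e.g. from ProcessGroupMesh);
# a distributed backend must already be initialized for this to run.
import torch
import torch.distributed as dist

def reduce_chunk_grad(global_chunk: torch.Tensor, zero_group, extra_dp_group=None) -> torch.Tensor:
    zero_size = dist.get_world_size(zero_group)

    # Step 1: reduce-scatter inside the ZeRO group, so each rank keeps one shard.
    shard = torch.empty(global_chunk.numel() // zero_size,
                        dtype=global_chunk.dtype, device=global_chunk.device)
    input_list = list(torch.chunk(global_chunk, chunks=zero_size, dim=0))
    dist.reduce_scatter(shard, input_list, group=zero_group)

    # Step 2: all-reduce the shard across the extra data-parallel replicas, if any.
    if extra_dp_group is not None:
        dist.all_reduce(shard, group=extra_dp_group)
    return shard
```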
10 changes: 6 additions & 4 deletions colossalai/zero/gemini/chunk/manager.py
@@ -38,7 +38,8 @@ def register_tensor(
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
extra_dp_group: ProcessGroup = None,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
@@ -76,15 +77,16 @@ def register_tensor(

if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
dp_size = dist.get_world_size(process_group)
dp_size = dist.get_world_size(zero_group)
chunk_size = chunk_size + (-chunk_size % dp_size)

chunk = Chunk(
chunk_size=chunk_size,
process_group=process_group,
zero_group=zero_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)

@@ -288,4 +290,4 @@ def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
# Release accumulated_grad
free_storage(accumulated_grad)

return grad_chunk
return grad_chunk
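Apart from the renaming, the chunk-size rule is unchanged: a tensor larger than the configured chunk size gets its own chunk, padded up to a multiple of the ZeRO group size. A quick worked example of the padding expression, with illustrative values rather than anything from the commit:

```python
# Padding rule from register_tensor: round chunk_size up to a multiple of the
# zero_group world size. Values below are illustrative only.
chunk_size = 1_000_003   # a tensor bigger than the configured chunk size
dp_size = 4              # dist.get_world_size(zero_group)

chunk_size = chunk_size + (-chunk_size % dp_size)
assert chunk_size % dp_size == 0
print(chunk_size)  # 1000004
```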
26 changes: 17 additions & 9 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -86,9 +86,10 @@ def __init__(
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
zero_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
master_weights: bool = True,
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -105,7 +106,7 @@ def __init__(
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
process_group=zero_group,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
@@ -128,7 +129,8 @@ def __init__(
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self.dp_process_group = process_group or _get_default_group()
self.zero_group = zero_group or _get_default_group()
self.extra_dp_group = extra_dp_group

self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
@@ -377,8 +379,12 @@ def grad_handle(self, p, grad):
self.chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if self.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if self.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
@@ -733,7 +739,7 @@ def load_parameter(chunk_slice, data):
unexpected_keys.append(key)

def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
dp_world_size = dist.get_world_size(self.dp_process_group)
zero_world_size = dist.get_world_size(self.zero_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
@@ -753,8 +759,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
self.chunk_manager.register_tensor(
tensor=p,
group_type="fp16_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
config_key=zero_world_size,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@@ -767,8 +774,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
config_key=zero_world_size,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@@ -881,4 +889,4 @@ def state_dict_shard(
if block is not None:
yield block, block_size

yield sharder.current_block, sharder.current_block_size
yield sharder.current_block, sharder.current_block_size
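In `grad_handle`, the reduced gradient is now scaled twice, once by the ZeRO group size and once by the extra-dp group size. Because the reduction sums contributions from both groups, this is the same as averaging over the full data-parallel degree. A small numerical check with assumed group sizes:

```python
import torch

zero_size, extra_dp_size = 2, 2                               # assumed group sizes
summed = torch.full((4,), float(zero_size * extra_dp_size))   # sum of per-replica grads of 1.0

grad = summed / zero_size        # divide by chunk.pg_size
grad = grad / extra_dp_size      # then by chunk.extra_dp_size
assert torch.allclose(grad, torch.ones(4))   # averaged over all data-parallel replicas
```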
24 changes: 16 additions & 8 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -1,5 +1,6 @@
from contextlib import nullcontext
from typing import Optional
import pytest

import torch
import torch.distributed as dist
@@ -17,14 +18,15 @@
from tests.kit.model_zoo import model_zoo


def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
try:
if init_method == "lazy":
ctx = LazyInitContext()
else:
ctx = nullcontext()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
enable_all_optimization = True if tp_size > 1 else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:

@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("enable_tensor_parallelism", [True, False])
def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
"""check gemini plugin over model zoo
Args:
@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):

# TODO debug blip2 when using tp, something wrong with shift_logits's shape
if "transformers_blip2" in name:
enable_tensor_parallelism = False
tp_size = 1

err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)

@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
spawn(run_dist, 8, early_stop=early_stop)


if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_gemini_plugin(early_stop=False)
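The new `test_gemini_plugin_3d` spawns 8 processes, so with the parameterized `zero_size=2` and `tp_size=2` the test exercises all three axes at once. A sketch of how `run_fn` derives the extra-dp degree, using assumed world sizes that mirror the test logic:

```python
# Illustration of the extra_dp_size derivation inside run_fn.
world_size = 8                                         # test_gemini_plugin_3d spawns 8 ranks
zero_size, tp_size = 2, 2
extra_dp_size = world_size // (zero_size * tp_size)   # -> 2

# When tensor parallelism is disabled for a model (e.g. transformers_blip2 sets
# tp_size = 1), the spare ranks fold into the extra-dp axis instead:
extra_dp_size_no_tp = world_size // (zero_size * 1)   # -> 4
```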