From 424629fea023a83aa84eacf55afc8007314d9f54 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 16 Aug 2023 15:41:20 +0800 Subject: [PATCH 01/33] [shardformer/sequence parallel] Cherry pick commit to new branch (#4450) * [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384) * [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation * add support for gpt2 fused qkv linear layer * support sequence parallel in GPT2 * add docstring and note * add requirments * remove unused flash-attb * modify flash attn test * modify flash attn setting * modify flash attn code * add assert before divide, rename forward function * [shardformer/test] fix gpt2 test with seq-parallel * [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401) * overlap gather input / grad computing during col backward * modify test for overlap * simplify code * fix code and modify cuda stream synchronize * [shardformer/sequence parallel] polish code --- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- colossalai/shardformer/layer/_operation.py | 276 +++++++++++++++++- colossalai/shardformer/layer/linear.py | 23 +- .../shardformer/layer/qkv_fused_linear.py | 27 +- colossalai/shardformer/modeling/gpt2_seq.py | 222 ++++++++++++++ .../shardformer/policies/base_policy.py | 26 +- colossalai/shardformer/policies/gpt2.py | 9 + colossalai/shardformer/shard/shard_config.py | 1 + .../test_gpt2_qkv_fused_linear_1d.py | 34 ++- .../test_layer/test_linear_1d.py | 75 +++-- tests/test_shardformer/test_model/_utils.py | 15 +- .../test_model/test_shard_gpt2.py | 7 + 12 files changed, 655 insertions(+), 65 deletions(-) create mode 100644 colossalai/shardformer/modeling/gpt2_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 28a19af0ce91..3d45a9112fce 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -152,6 +152,7 @@ def __init__( enable_fused_normalization: bool = False, enable_flash_attention: bool = False, enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, num_microbatches: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, @@ -178,6 +179,7 @@ def __init__( self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None @@ -195,7 +197,8 @@ def __init__( enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused) + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01b33..13e563123d28 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch.distributed as dist import 
torch.nn.functional as F @@ -141,6 +143,215 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + # create new stream for calculate the gradient + calculate_stream = torch.cuda.Stream() + + # do all gather in default stream + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + + # calculate gradient in calculate_stream + with torch.cuda.stream(calculate_stream): + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, 
dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + + torch.cuda.current_stream().wait_stream(calculate_stream) + + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + with torch.cuda.stream(calculate_stream): + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + print(grad_output.shape, input_parallel.shape) + grad_weight = grad_output.t().matmul(input_parallel) + + torch.cuda.current_stream().wait_stream(calculate_stream) + + return output, grad_weight, grad_bias, None, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. 
+ dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + return output, grad_weight, grad_bias, None, None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -200,6 +411,26 @@ def backward(ctx, grad_output): return _reduce(grad_output, ctx.process_group), None +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -235,6 +466,7 @@ def _gather(input_, dim=-1, process_group=None): return input_ # all gather + input_ = input_.contiguous() rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ @@ -246,24 +478,27 @@ def _gather(input_, dim=-1, process_group=None): return output -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. +def _reduce_scatter(input_, dim=1, process_group=None): + """ Do reduce-scatter operation. Args: - input_: input matrix. 
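For reference, the async_grad_reduce_scatter branches above share one comm/compute overlap idiom: launch the reduce-scatter with async_op=True, compute the weight gradient while the collective is in flight, then wait on the handle. A condensed sketch of that idiom (function and tensor names here are illustrative, not taken from the patch):

import torch
import torch.distributed as dist


def overlapped_grad_step(grad_input, grad_output_2d, total_input_2d, process_group, dim=1):
    # chunk the full-sequence input gradient and start reduce-scattering it asynchronously
    world_size = dist.get_world_size(process_group)
    chunks = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
    sharded_grad_input = torch.empty_like(chunks[0])
    handle = dist.reduce_scatter(sharded_grad_input, chunks, group=process_group, async_op=True)

    # while the collective runs, compute the weight gradient on the default stream
    grad_weight = grad_output_2d.t().matmul(total_input_2d)

    handle.wait()    # the sharded input gradient must be ready before it is returned
    return sharded_grad_input, grad_weight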
- parallel_mode: parallel mode. - dim: dimension + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ - @staticmethod - def forward(ctx, input_, dim, process_group): - ctx.process_group = process_group - ctx.dim = dim - return _gather(input_, dim, process_group) + # reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_, group=process_group) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return output def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -274,6 +509,21 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): + return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim, overlap) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d59b68ce4480..69ac3ad2581a 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,8 @@ from ._operation import ( gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, split_forward_gather_backward, @@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule): gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. 
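When seq_parallel is enabled, this column-parallel path goes through the gather-forward / reduce-scatter-backward functions added to _operation.py above. A minimal standalone sketch of that pattern, without bias handling or the overlap optimization (names are illustrative; dim is the sequence dimension):

import torch
import torch.distributed as dist
import torch.nn.functional as F


class _GatherForwardReduceScatterBackwardSketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, weight, process_group, dim):
        ctx.save_for_backward(input_, weight)
        ctx.process_group, ctx.dim = process_group, dim
        world_size = dist.get_world_size(process_group)
        # all-gather the per-rank sequence shards so every rank sees the full sequence
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gather_list, input_.contiguous(), group=process_group)
        input_parallel = torch.cat(gather_list, dim=dim)
        return F.linear(input_parallel, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        process_group, dim = ctx.process_group, ctx.dim
        world_size = dist.get_world_size(process_group)

        # gradient w.r.t. the full-sequence input, then reduce-scatter it back to shards
        grad_input_full = grad_output.matmul(weight)
        chunks = [c.contiguous() for c in torch.chunk(grad_input_full, world_size, dim=dim)]
        grad_input = torch.empty_like(input_)
        dist.reduce_scatter(grad_input, chunks, group=process_group)

        # re-gather the saved shard to form the full-sequence input for the weight gradient
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gather_list, input_.contiguous(), group=process_group)
        input_parallel = torch.cat(gather_list, dim=dim)
        grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
            input_parallel.reshape(-1, input_parallel.shape[-1]))
        return grad_input, grad_weight, None, None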
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): @@ -69,6 +73,8 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -80,6 +86,8 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -180,7 +188,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1, self.overlap) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. 
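With seq_parallel enabled, the row-parallel path goes the other way: the forward all-reduce of partial results is replaced by a reduce-scatter along the sequence dimension, and the backward becomes an all-gather. A rough sketch of that primitive (illustrative only, assuming the sequence length divides evenly across ranks):

import torch
import torch.distributed as dist


class _ReduceScatterForwardGatherBackwardSketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.process_group, ctx.dim = process_group, dim
        world_size = dist.get_world_size(process_group)
        chunks = [c.contiguous() for c in torch.chunk(input_, world_size, dim=dim)]
        output = torch.empty_like(chunks[0])
        # sum partial outputs across ranks while keeping only this rank's sequence shard
        dist.reduce_scatter(output, chunks, group=process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        gather_list = [torch.empty_like(grad_output) for _ in range(world_size)]
        # the gradient of a reduce-scatter is an all-gather along the same dimension
        dist.all_gather(gather_list, grad_output.contiguous(), group=ctx.process_group)
        return torch.cat(gather_list, dim=ctx.dim), None, None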
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -221,6 +235,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -238,6 +253,7 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -373,7 +389,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index df942d43ee2d..ccb2bf7ea4cc 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,7 +25,9 @@ from ._operation import ( gather_forward_split_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, + matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -173,6 +176,7 @@ def __init__(self, process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, + seq_parallel: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -185,6 +189,7 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -296,15 +301,19 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) - # input_parallel = input_ # Matrix multiply. 
bias = self.bias if not self.skip_bias_add else None - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + if self.seq_parallel: + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1) + else: + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -329,6 +338,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. @@ -346,6 +356,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -363,6 +374,7 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -499,7 +511,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py new file mode 100644 index 000000000000..a6da96e7bf73 --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2_seq.py @@ -0,0 +1,222 @@ +# this code is modified from transformers.models.gpt2.modeling_gpt2 +# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670 + +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import logging + +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +# TODO: put all contents in `gpt2.py` and make it compatible with pipeline +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: 
Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + 
output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 69493bfb6007..7022a1cfd7a2 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -11,17 +11,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager +from ..layer.parallel_module import ParallelModule from ..shard.shard_config import ShardConfig __all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] -class ParallelModule(): - - def __init__(self): - pass - - @dataclass class SubModuleReplacementDescription: r""" @@ -231,3 +226,22 @@ def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] + + def append_seq_parallel_to_policy( + self, + suffix_list: List[str], + module_policy_description: ModulePolicyDescription, + ): + r""" + Append the sequence parallel policy to the policy for the given key. 
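The GPT-2 forward above only shards activations along the sequence dimension between the embedding dropout and the final layer norm. A small single-process shape walk-through, assuming a tensor-parallel size of 2 and toy dimensions:

import torch

tp_size = 2
batch_size, seq_len, hidden_size = 2, 8, 16

hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# split_forward_gather_backward(..., dim=1): each rank keeps one sequence shard
shard = torch.chunk(hidden_states, tp_size, dim=1)[0]    # rank 0's shard
assert shard.shape == (batch_size, seq_len // tp_size, hidden_size)

# ... the GPT2 blocks then run on [batch_size, seq_len / tp_size, hidden_size] ...

# gather_forward_split_backward(..., dim=1): the full sequence is restored before ln_f
gathered = torch.cat(torch.chunk(hidden_states, tp_size, dim=1), dim=1)
assert gathered.shape == (batch_size, seq_len, hidden_size)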
+ + Args: + suffix_list (List[str]): the suffix list of the module to be parallelized + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + """ + + for sub_description in module_policy_description.sub_module_replacement: + if (sub_description.suffix in suffix_list): + if sub_description.kwargs is None: + sub_description.kwargs = {} + sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 20e5fa372c8f..276d95660c4d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -7,6 +7,7 @@ from .._utils import getattr_, setattr_ from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward +from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -49,6 +50,9 @@ def module_policy(self): target_module=col_nn.DropoutForParallelInput, ), ]) + if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -120,6 +124,11 @@ def module_policy(self): policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ 'forward': get_gpt2_flash_attention_forward(), }) + + if self.shard_config.enable_sequence_parallelism: + suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] + self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) + return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0c28f115d018..a36e878c623f 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -28,6 +28,7 @@ class ShardConfig: enable_all_optimization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False + enable_sequence_parallelism: bool = False # pipeline_parallel_size: int # data_parallel_size: int diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index b45cd172c3ca..ae6a1dc90dc5 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_col(lazy_init: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool): linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, + seq_parallel=seq_parallel, n_fused=3) assert linear.weight.shape == torch.Size([48, 192]) @@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool): linear.load_state_dict(linear_conv_col.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 
48).cuda() out = linear(x) - gather_out = linear_conv_col(x) - assert_close(rearrange(out, 1), gather_out) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + gather_out = linear_conv_col(x_for_shard) + assert_close(rearrange(out, -1), gather_out) # check backward correctness out.sum().backward() @@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_row(lazy_init: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool): linear.load_state_dict(linear_row.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool): assert_close(target_grad, linear_row.weight.grad) +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel) + check_linear_conv_1d_row(lazy_init, seq_parallel) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # test for linear conv - check_linear_conv_1d_col() - check_linear_conv_1d_row() + check_gpt2_qkv_fused_linear_1d() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index aa75879e0313..3ad8f14b99e6 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -12,13 +12,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) -def check_linear_1d_col(lazy_init: bool): +def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) + linear_col = Linear1D_Col.from_native_module(linear_copy, + process_group=None, + gather_output=True, + seq_parallel=seq_parallel, + overlap=overlap) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool): linear_col.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x 
= torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_1d_row(lazy_init: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = Linear1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool): linear_row.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) x_for_shard = x.expand_as(x.clone()) @@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_col_plus_row(lazy_init: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool): with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + linear_col = Linear1D_Col.from_native_module(linear_1_copy, + process_group=None, + gather_output=False, + seq_parallel=seq_parallel, + overlap=overlap) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, + process_group=None, + parallel_input=True, + seq_parallel=seq_parallel) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool): linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) 
x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - assert_close(unshard_out, shard_out) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) # check backward correctness unshard_out.sum().backward() @@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) + + +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +@parameterize('overlap', [False, True]) +def run_dist_linear_test(lazy_init, seq_parallel, overlap): + check_linear_1d_col(lazy_init, seq_parallel, overlap) + check_linear_1d_row(lazy_init, seq_parallel) + check_linear_col_plus_row(lazy_init, seq_parallel, overlap) -def run_dist(rank, world_size, port): +def check_dist_linear(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_linear_1d_col() - check_linear_1d_row() - check_linear_col_plus_row() + run_dist_linear_test() @rerun_if_address_is_in_use() def test_linear(): - spawn(run_dist, nprocs=2) + spawn(check_dist_linear, nprocs=2) if __name__ == '__main__': diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 921af2a8b1d0..7e1e6f2fe03a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,4 +1,5 @@ import copy +import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -25,6 +26,7 @@ def build_model(model_fn, enable_tensor_parallelism=True, enable_flash_attention=False, enable_jit_fused=False, + enable_sequence_parallelism=False, use_lazy_init: bool = False): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() @@ -38,7 +40,8 @@ def build_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) @@ -135,6 +138,16 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() + + if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + seq_len = data['input_ids'].shape[1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data['input_ids'].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat(1, times) + sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in data.items(): diff --git 
a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ca086bf12776..c97702cbb281 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -106,6 +106,13 @@ def unwrap(module): 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': False, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', }]) @clear_cache_before_run() def run_gpt2_test(test_config): From 6ef33f75aa05390894e411296acf8db8a0b55118 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 16 Aug 2023 16:11:57 +0800 Subject: [PATCH 02/33] [shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446) * support DDP for HybridPlugin/add tp+dp tests * add docstring for HybridParallelPlugin --- .../booster/plugin/hybrid_parallel_plugin.py | 129 ++++++++++++++---- tests/test_shardformer/test_model/_utils.py | 13 ++ .../test_model/test_shard_bert.py | 17 ++- .../test_model/test_shard_bloom.py | 19 +-- .../test_model/test_shard_chatglm.py | 19 +-- .../test_model/test_shard_gpt2.py | 21 ++- .../test_model/test_shard_llama.py | 20 +-- .../test_model/test_shard_opt.py | 20 +-- .../test_model/test_shard_t5.py | 20 +-- .../test_model/test_shard_vit.py | 21 +-- 10 files changed, 199 insertions(+), 100 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3d45a9112fce..00c714fe4612 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -6,7 +6,8 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.nn import Module +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -28,7 +29,8 @@ class HybridParallelModule(ModelWrapper): - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None: + def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, + ddp_config: dict) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) @@ -45,7 +47,15 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp module = module.to(dtype=torch.bfloat16).cuda() else: module = module.cuda() # train without AMP - # TODO(ver217): support TP+DP + + if use_ddp: + + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + super().__init__(module) def sync_shared_params(self): @@ -68,6 +78,12 @@ def sync_grads(self): dist.all_reduce(p.grad, group=self.dp_group) p.grad.div_(self.dp_group.size()) + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) @@ -140,29 +156,81 @@ def __init__( class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. 
+ Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True. + bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False. 
+ static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False. + """ + + def __init__(self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + cpu_offload: bool = False, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + num_microbatches: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers=True, + bucket_cap_mb=25, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False, + static_graph=False) -> None: - def __init__( - self, - tp_size: int, - pp_size: int, - precision: str = 'fp16', - zero_stage: int = 0, - cpu_offload: bool = False, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - num_microbatches: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - ) -> None: super().__init__() assert dist.get_world_size() % ( tp_size * pp_size @@ -208,6 +276,13 @@ def __init__( min_scale=min_scale, max_scale=max_scale, ) + + self.ddp_config = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) self.max_norm = max_norm @property @@ -241,7 +316,9 @@ def configure( lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7e1e6f2fe03a..789b3b24e696 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -13,6 +13,7 @@ from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -259,3 +260,15 @@ def check_grad(org_model: Module, assert torch.allclose( org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + + +def unwrap_model(module: Module, + base_model_class_name: Optional[str] = None, + base_model_attribute_name: Optional[str] = None): + if 
isinstance(module, HybridParallelModule): + module = module.unwrap() + if base_model_class_name is None: + return module + if module.__class__.__name__ == base_model_class_name: + return module + return getattr(module, base_model_attribute_name, None) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0a24e46d28f2..49de9cc0311c 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -15,6 +15,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -44,13 +45,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'BertModel': - bert = org_model - sharded_bert = sharded_model.unwrap() - else: - bert = org_model.bert - sharded_bert = sharded_model.unwrap().bert + + bert = unwrap_model(org_model, 'BertModel', 'bert') + sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') col_layer_for_check = ['encoder.layer[0].output.dense'] row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] @@ -98,6 +95,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index ed0d1d8e401d..af014a8585b5 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -13,6 +13,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -46,12 +47,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'BloomModel': - bloom = org_model - sharded_bloom = sharded_model.unwrap() - else: - bloom = org_model.transformer - sharded_bloom = sharded_model.unwrap().transformer + bloom = unwrap_model(org_model, 'BloomModel', 'transformer') + sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') # check grad row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] @@ -97,12 +94,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_bloom_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index bb77759048b3..210f775b540d 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py 
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -14,6 +14,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'ChatGLMModel': - chatglm_model = org_model - shard_chatglm_model = sharded_model.unwrap() - else: - chatglm_model = org_model.transformer - shard_chatglm_model = sharded_model.unwrap().transformer + chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') + shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') # check grad row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] @@ -121,12 +118,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_chatglm_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c97702cbb281..97295f72f4e1 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -3,7 +3,6 @@ from torch import distributed as dist import colossalai -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -15,6 +14,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,16 +48,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - def unwrap(module): - if isinstance(module, HybridParallelModule): - module = module.unwrap() - if module.__class__.__name__ == 'GPT2Model': - return module - return module.transformer - # unwrap model - gpt2 = unwrap(org_model) - sharded_gpt2 = unwrap(sharded_model) + gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') + sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') col_layer_for_check = ['h[0].mlp.c_fc'] row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] @@ -106,6 +99,12 @@ def unwrap(module): 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, @@ -117,8 +116,6 @@ def unwrap(module): @clear_cache_before_run() def run_gpt2_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git 
a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 30ebdfbe5cd9..a433567b3702 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -16,6 +16,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -52,12 +53,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'LlamaModel': - llama_model = org_model - shard_llama_model = sharded_model.unwrap() - else: - llama_model = org_model.model - shard_llama_model = sharded_model.unwrap().model + llama_model = unwrap_model(org_model, 'LlamaModel', 'model') + shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] @@ -128,13 +125,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_llama_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 8d1154d82638..2fb14903b6a9 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -16,6 +16,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -51,12 +52,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'OPTModel': - opt_model = org_model - shard_opt_model = sharded_model.unwrap() - else: - opt_model = org_model.model - shard_opt_model = sharded_model.unwrap().model + opt_model = unwrap_model(org_model, 'OPTModel', 'model') + shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') # check grad row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' @@ -123,14 +120,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_opt_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py 
b/tests/test_shardformer/test_model/test_shard_t5.py index 066f7ee815b4..234ce812a08c 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,5 +1,6 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers @@ -14,6 +15,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,8 +50,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - t5 = org_model - sharded_t5 = sharded_model.unwrap() + t5 = unwrap_model(org_model) + sharded_t5 = unwrap_model(sharded_model) row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] @@ -99,17 +101,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) @clear_cache_before_run() def run_t5_test(test_config): - # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO(baizhou): add test_config for flash attention & jit operator after supporting - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 18df8ef555f2..b9d303841215 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -14,6 +14,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'ViTModel': - vit_model = org_model - shard_vit_model = sharded_model.unwrap() - else: - vit_model = org_model.vit - shard_vit_model = sharded_model.unwrap().vit + vit_model = unwrap_model(org_model, 'ViTModel', 'vit') + shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] @@ -120,15 +117,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_vit_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, 
loss_fn, test_config) From 26e29d58f0525ff573d6a2eeae328e0a4d7f9a68 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 16 Aug 2023 18:56:52 +0800 Subject: [PATCH 03/33] [devops] add large-scale distributed test marker (#4452) * [test] remove cpu marker * [test] remove gpu marker * [test] update pytest markers * [ci] update unit test ci --- .github/workflows/build_on_pr.yml | 2 +- .../compatiblity_test_on_dispatch.yml | 2 +- .github/workflows/compatiblity_test_on_pr.yml | 2 +- .../compatiblity_test_on_schedule.yml | 2 +- applications/Chat/tests/test_dataset.py | 79 ++++++------- applications/Chat/tests/test_models.py | 105 +++++++----------- pytest.ini | 6 +- tests/test_config/test_load_config.py | 1 - tests/test_context/test_hybrid_parallel.py | 1 - tests/test_data/test_cifar10_dataset.py | 3 +- tests/test_data/test_data_parallel_sampler.py | 1 - .../test_deterministic_dataloader.py | 1 - .../test_activation_checkpointing.py | 1 - 13 files changed, 81 insertions(+), 125 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 8a1bc8e113de..4c7e08e5799e 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -208,7 +208,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 1778d64ee287..63c0fbbb975d 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -44,7 +44,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index c0f45c65a7fc..c9f84806be30 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -35,7 +35,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 15ac4f1a92bb..3f8fc96395c9 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -32,7 +32,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py index 64ea1178cd0d..1d9aa50e2c8f 100644 --- a/applications/Chat/tests/test_dataset.py +++ b/applications/Chat/tests/test_dataset.py @@ -14,29 +14,43 @@ SFT_DATASET = [ { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - 
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 + "instruction": + "Provide a list of the top 10 most popular mobile games in Asia", + "input": + "", + "output": + "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": + 0 }, { - "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", - "input": "", - "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", - "id": 1 + "instruction": + "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": + "", + "output": + "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": + 1 }, { - "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", - "input": "", - "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", - "id": 2 + "instruction": + "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": + "", + "output": + "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. 
\n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": + 2 }, ] PROMPT_DATASET = [ { - "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": 0 + "instruction": + "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", + "id": + 0 }, { "instruction": "Write a descriptive paragraph about a memorable vacation you went on", @@ -71,9 +85,7 @@ def make_tokenizer(model: str): return tokenizer -def check_content(input_ids_stripped: torch.Tensor, - tokenizer: PreTrainedTokenizer, - model: str): +def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str): if model == "opt": # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt. assert input_ids_stripped[0] == tokenizer.eos_token_id @@ -90,13 +102,10 @@ def check_content(input_ids_stripped: torch.Tensor, assert input_ids_stripped != tokenizer.mask_token_id -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_datasets_size", [2]) -def test_prompt_dataset(model: str, - max_datasets_size: int, - max_length: int): +def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): with tempfile.TemporaryDirectory() as tmp_dir: dataset_name = "prompt_dataset.json" with open(os.path.join(tmp_dir, dataset_name), "w") as f: @@ -119,19 +128,12 @@ def test_prompt_dataset(model: str, check_content(input_ids.masked_select(attention_mask), tokenizer, model) -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize(["dataset_path", "subset"], [ - ("Anthropic/hh-rlhf", "harmless-base"), - ("Dahoas/rm-static", None) -]) +@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), + ("Dahoas/rm-static", None)]) @pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_length", [32, 1024]) -def test_reward_dataset(model: str, - dataset_path: str, - subset: Optional[str], - max_datasets_size: int, - max_length: int): +def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): data = load_dataset(dataset_path, data_dir=subset) assert max_datasets_size <= len(data["train"]) \ and max_datasets_size <= len(data["test"]) @@ -188,15 +190,11 @@ def test_reward_dataset(model: str, assert torch.all(r_mask) -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) @pytest.mark.parametrize("max_dataset_size", [2]) @pytest.mark.parametrize("max_length", [32, 1024]) -def test_sft_dataset(model: str, - dataset_path: Optional[str], - max_dataset_size: int, - max_length: int): 
+def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int): tokenizer = make_tokenizer(model) if dataset_path == "yizhongw/self_instruct": data = load_dataset(dataset_path, "super_natural_instructions") @@ -232,10 +230,7 @@ def test_sft_dataset(model: str, if __name__ == "__main__": - test_sft_dataset(model="bloom", - dataset_path="yizhongw/self_instruct", - max_dataset_size=2, - max_length=256) + test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) test_reward_dataset(model="gpt2", dataset_path="Anthropic/hh-rlhf", @@ -243,6 +238,4 @@ def test_sft_dataset(model: str, max_datasets_size=8, max_length=256) - test_prompt_dataset(model="opt", - max_datasets_size=2, - max_length=128) + test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128) diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py index bd6b3e8a5ad1..e96ff8bd7aa7 100644 --- a/applications/Chat/tests/test_models.py +++ b/applications/Chat/tests/test_models.py @@ -15,16 +15,17 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean -@pytest.mark.gpu @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seq_len", [32]) -@pytest.mark.parametrize("actor_maker", [ - lambda: BLOOMActor(), - lambda: GPTActor(), +@pytest.mark.parametrize( + "actor_maker", + [ + lambda: BLOOMActor(), + lambda: GPTActor(), # HACK: skip llama due to long execution time # lambda: LlamaActor(), - lambda: OPTActor() -]) + lambda: OPTActor() + ]) @pytest.mark.parametrize("generate_kwargs", [{ "max_length": 64, "use_cache": True, @@ -32,23 +33,15 @@ "temperature": 1.0, "top_k": 50, }]) -def test_generation(actor_maker: Callable[[], Actor], - batch_size: int, - seq_len: int, - generate_kwargs: Dict[str, Any] - ): +def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): actor = actor_maker() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() sequences = generate(actor.cuda(), input_ids, **generate_kwargs) assert sequences.shape == (batch_size, generate_kwargs["max_length"]) -@pytest.mark.cpu def test_utils(): - fn_input = { - "tensor": torch.ones((10, )), - "mask": torch.randint(0, 2, (10, )) - } + fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))} fn_output = masked_mean(dim=0, **fn_input) assert fn_output.dim() == 0 assert torch.allclose(fn_output, torch.tensor(1.0)) @@ -56,14 +49,14 @@ def test_utils(): batch_size = 4 num_labels = 10 fn_input = { - "r": torch.ones((batch_size, )), + "r": torch.ones((batch_size,)), "kl_coef": 1.0, "log_probs": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)), "action_mask": torch.randint(0, 2, (batch_size, num_labels)) } fn_output = compute_reward(**fn_input) - assert fn_output.shape == (batch_size, ) + assert fn_output.shape == (batch_size,) batch_size = 4 seq_len = 32 @@ -80,17 +73,11 @@ def test_utils(): assert fn_output.shape == (batch_size, num_actions) -@pytest.mark.cpu @pytest.mark.parametrize("lora_rank", [4]) @pytest.mark.parametrize("num_dim", [32]) @pytest.mark.parametrize("num_layers", [4]) -def test_lora(lora_rank: int, - num_dim: int, - num_layers: int): - model = nn.ModuleList( - [nn.Linear(num_dim, num_dim) - for _ in range(num_layers)] - ) +def test_lora(lora_rank: int, num_dim: int, num_layers: int): + model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in 
range(num_layers)]) lora_model = convert_to_lora_module(model, lora_rank) assert isinstance(lora_model, nn.ModuleList) for i in range(num_layers): @@ -103,8 +90,7 @@ def test_lora(lora_rank: int, assert isinstance(lora_model[i], LoraLinear) assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, - lora_model[i].lora_B @ lora_model[i].lora_A) + assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A) optimizer = torch.optim.Adam(lora_model.parameters()) x = torch.randn(8, num_dim) for i in range(num_layers): @@ -120,20 +106,19 @@ def test_lora(lora_rank: int, lora_model[i].lora_B @ lora_model[i].lora_A) -@pytest.mark.cpu @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [128]) -@pytest.mark.parametrize("models_maker", [ - lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), - lambda: (GPTActor(), GPTCritic(), GPTRM()), +@pytest.mark.parametrize( + "models_maker", + [ + lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), + lambda: (GPTActor(), GPTCritic(), GPTRM()), # HACK: skip llama due to long execution time # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), - lambda: (OPTActor(), OPTCritic(), OPTRM()), -]) + lambda: (OPTActor(), OPTCritic(), OPTRM()), + ]) @torch.no_grad() -def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], - batch_size: int, - seq_len: int): +def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int): actor_input = { "input_ids": torch.randint(0, 100, (batch_size, seq_len)), @@ -162,17 +147,14 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], rm_output = rm(**rm_input) assert actor_output.logits.shape[:2] == (batch_size, seq_len) - assert critic_output.shape == (batch_size, ) - assert rm_output.shape == (batch_size, ) + assert critic_output.shape == (batch_size,) + assert rm_output.shape == (batch_size,) -@pytest.mark.cpu @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("seq_len", [128]) @pytest.mark.parametrize("num_labels", [100]) -def test_loss(batch_size: int, - seq_len: int, - num_labels: int): +def test_loss(batch_size: int, seq_len: int, num_labels: int): loss = GPTLMLoss() loss_input = { "logits": torch.randn(batch_size, seq_len, num_labels), @@ -182,54 +164,43 @@ def test_loss(batch_size: int, loss = PolicyLoss() loss_input = { - "log_probs": torch.randn(batch_size, ), - "old_log_probs": torch.randn(batch_size, ), - "advantages": torch.randn(batch_size, ) + "log_probs": torch.randn(batch_size,), + "old_log_probs": torch.randn(batch_size,), + "advantages": torch.randn(batch_size,) } loss_output = loss(**loss_input) loss = ValueLoss() loss_input = { - "values": torch.randn(batch_size, ), - "old_values": torch.randn(batch_size, ), - "reward": torch.randn(batch_size, ) + "values": torch.randn(batch_size,), + "old_values": torch.randn(batch_size,), + "reward": torch.randn(batch_size,) } loss_output = loss(**loss_input) loss = LogSigLoss() loss_input = { - "chosen_reward": torch.randn(batch_size, ), - "reject_reward": torch.randn(batch_size, ), + "chosen_reward": torch.randn(batch_size,), + "reject_reward": torch.randn(batch_size,), } loss_output = loss(**loss_input) loss = LogExpLoss() loss_input = { - "chosen_reward": torch.randn(batch_size, ), - "reject_reward": torch.randn(batch_size, ), + "chosen_reward": 
torch.randn(batch_size,), + "reject_reward": torch.randn(batch_size,), } loss_output = loss(**loss_input) if __name__ == "__main__": - generate_kwargs = dict(max_length=40, - use_cache=True, - do_sample=True, - temperature=1.0, - top_k=50) - test_generation(lambda: LlamaActor(), - batch_size=4, - seq_len=32, - generate_kwargs=generate_kwargs) + generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50) + test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs) test_utils() test_lora(lora_rank=2, num_dim=8, num_layers=2) - test_models(models_maker=lambda: (BLOOMActor(), - BLOOMCritic(), - BLOOMRM()), - batch_size=8, - seq_len=128) + test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/pytest.ini b/pytest.ini index e8a60c85336b..7912dbffc6ef 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,5 @@ [pytest] markers = - cpu: tests which can run on CPU - gpu: tests which requires a single GPU - dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features + dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) + largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 550af2a4ae81..38b5e3f5f4fc 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -8,7 +8,6 @@ from colossalai.context.config import Config -@pytest.mark.cpu def test_load_config(): filename = Path(__file__).parent.joinpath('sample_config.py') config = Config.from_file(filename) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py index 9f26a5af53ce..d25668afd430 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_context/test_hybrid_parallel.py @@ -143,7 +143,6 @@ def run_dist(rank, world_size, port, backend, port_list, host): reset_seeds() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_context(): """ diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py index 4b9ca61d9f17..dfa9fa211ef0 100644 --- a/tests/test_data/test_cifar10_dataset.py +++ b/tests/test_data/test_cifar10_dataset.py @@ -5,11 +5,10 @@ from pathlib import Path import pytest -from torchvision import transforms, datasets from torch.utils.data import DataLoader +from torchvision import datasets, transforms -@pytest.mark.cpu def test_cifar10_dataset(): # build transform transform_pipeline = [transforms.ToTensor()] diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py index 2ad3fd696c39..7beef707c096 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_data/test_data_parallel_sampler.py @@ -53,7 +53,6 @@ def run_data_sampler(rank, world_size, port): torch.cuda.empty_cache() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): spawn(run_data_sampler, 4) diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py index 239e79dff7d8..283b5cc35279 100644 --- a/tests/test_data/test_deterministic_dataloader.py +++ 
b/tests/test_data/test_deterministic_dataloader.py @@ -64,7 +64,6 @@ def run_data_sampler(rank, world_size, port): torch.cuda.empty_cache() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): spawn(run_data_sampler, 4) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 2930552cc4e7..b7764c2f4371 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -40,7 +40,6 @@ def forward_inplace(x, weight): return out -@pytest.mark.gpu @clear_cache_before_run() @parameterize("use_reentrant", [True, False]) @parameterize("cpu_offload", [True, False]) From a78daf6180cec55b37713418cad8f406f57939e8 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Wed, 16 Aug 2023 19:29:03 +0800 Subject: [PATCH 04/33] [shardformer] support interleaved pipeline (#4448) * support interleaved pipeline * fix unit test * remove virtual stage test in stage mgr * add droped type hint and updated bwd --- colossalai/cluster/process_group_mesh.py | 10 +- colossalai/pipeline/p2p.py | 45 +-- .../pipeline/schedule/interleaved_pp.py | 370 ++++++++++++++++++ colossalai/pipeline/schedule/one_f_one_b.py | 78 +++- colossalai/pipeline/stage_manager.py | 78 +--- .../test_schedule/test_interleaved.py | 161 ++++++++ tests/test_pipeline/test_stage_manager.py | 9 - 7 files changed, 642 insertions(+), 109 deletions(-) create mode 100644 colossalai/pipeline/schedule/interleaved_pp.py create mode 100644 tests/test_pipeline/test_schedule/test_interleaved.py diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1dfd261d5d01..623160003767 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -94,17 +94,23 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: return np.unravel_index(rank, shape) @staticmethod - def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: """Convert a coordinate to a rank. + mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. + with wrap, index out of range would be wrapped around. + For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2) Args: coords (Tuple[int, ...]): Coordinate to be converted. shape (Tuple[int, ...]): Shape of the process group mesh. + mode (Optional[str]): The mode for numpy.ravel_multi_index. Returns: int: Rank of the coordinate. """ - return np.ravel_multi_index(coord, shape) + + assert mode in ["raise", "wrap", "clip"] + return np.ravel_multi_index(coord, shape, mode) def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. It the process group doesn't exist, it will be created. diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index af7a00b5c720..aed85cf91512 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -173,14 +173,10 @@ def recv_forward(self, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. 
""" - if self.stage_manager.is_first_stage(): - input_tensor = None - else: - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object(prev_rank, cur_rank, - self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) return input_tensor @@ -193,14 +189,11 @@ def recv_backward(self, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(): - output_tensor_grad = None - else: - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object(next_rank, cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + output_tensor_grad = _recv_object(next_rank, cur_rank, + self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) return output_tensor_grad @@ -211,12 +204,10 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.stage_manager.is_last_stage(): - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(output_object, cur_rank, next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. @@ -225,9 +216,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: input_object (Any): Object to be sent. 
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
-        if not self.stage_manager.is_first_stage():
-            if prev_rank is None:
-                prev_rank = self.stage_manager.get_prev_rank()
-            cur_rank = self.stage_manager.get_rank()
-            _send_object(input_object, cur_rank, prev_rank,
-                         self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+        cur_rank = self.stage_manager.get_rank()
+        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
new file mode 100644
index 000000000000..35a33491b03c
--- /dev/null
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -0,0 +1,370 @@
+from functools import partial
+from typing import Any, Callable, Iterable, List, Optional, Union
+
+import torch
+import torch.cuda
+from torch.nn import Module
+from torch.utils._pytree import tree_map
+
+from colossalai.interface import OptimizerWrapper
+from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils.cuda import get_current_device
+
+from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from .base import PipelineSchedule
+
+
+class InterleavedSchedule(PipelineSchedule):
+
+    def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
+        self.num_model_chunks = num_model_chunks
+        assert num_microbatches % self.num_model_chunks == 0, \
+            "Number of microbatches should be an integer multiple of number of model chunks"
+        super().__init__(stage_manager)
+        self.comm = PipelineP2PCommunication(stage_manager)
+        self.num_microbatches = num_microbatches
+        self.batch: Optional[Any] = None
+        self.batch_size: Optional[int] = None
+        self.microbatch_offset: Optional[int] = None
+        self.microbatch_size: Optional[int] = None
+
+    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
+        """Load a batch from data iterator.
+
+        Args:
+            data_iter (Iterable): Data iterator.
+            device (Optional[torch.device], optional): Target device. Defaults to None.
+        """
+        batch = next(data_iter)
+        if device is not None:
+            batch = tree_map(partial(to_device, device=device), batch)
+        self.batch = batch
+        self.batch_size = get_batch_size(batch)
+        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
+        assert self.batch_size % self.num_microbatches == 0, \
+            "Batch size should be divisible by the number of microbatches"
+        self.microbatch_size = self.batch_size // self.num_microbatches
+
+    def load_micro_batch(self, model_chunk_id: int) -> Any:
+        """Load a micro batch from the current batch.
+
+        Args:
+            model_chunk_id (int): The current model chunk idx.
+
+        Returns:
+            Any: Micro batch.
+        """
+        micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
+        self.microbatch_offset[model_chunk_id] += self.microbatch_size
+        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+
+    def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
+        """Helper method to get the model chunk ID given the iteration number.
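+
+        For example, with 4 pipeline stages and 2 model chunks, forward microbatches 0-3
+        are mapped to model chunk 0 and microbatches 4-7 to model chunk 1; the mapping is
+        reversed for the backward direction.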
+ + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not forward: + model_chunk_id = (self.num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def is_first_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the first stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the first stage. + """ + if self.stage_manager.is_first_stage() and model_chunk_id == 0: + return True + return False + + def is_last_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the last stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the last stage. + """ + if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: + return True + return False + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.is_first_stage(model_chunk_id): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.is_last_stage(model_chunk_id): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.is_last_stage(model_chunk_id): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + input_object (Any): Object to be sent. 
+            prev_rank (int, optional): The rank of the recipient of the tensor
+        """
+        if not self.is_first_stage(model_chunk_id):
+            self.comm.send_backward(input_object, prev_rank)
+
+    def forward_step(self,
+                     model_chunk: Module,
+                     model_chunk_id: int,
+                     input_obj: Optional[dict],
+                     criterion: Callable,
+                     accum_loss: Optional[torch.Tensor] = None,
+                     outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+        """Forward one step of the pipeline
+        Args:
+            model_chunk (Module): The model chunk to be run.
+            model_chunk_id (int): The current model chunk idx.
+            input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
+            criterion (Callable): Criterion to calculate loss.
+            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+        Returns:
+            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+        """
+        micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
+
+        # for the first stage, input_obj is None
+        # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
+        output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+
+        if self.is_last_stage(model_chunk_id):
+            loss = criterion(output_obj, micro_batch) / self.num_microbatches
+            if accum_loss is not None:
+                accum_loss.add_(loss.detach())
+            if outputs is not None:
+                outputs.append(tree_map(detach, output_obj))
+            return loss
+        else:
+            return output_obj
+
+    def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
+                      output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+        """Backward one step of the pipeline
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to update the model
+            input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
+            output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
+            output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
+
+        Returns:
+            Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
+        """
+
+        # Retain the grad on the input_obj.
+        tree_map(retain_grad, input_obj)
+
+        # Backward pass.
+        if output_obj_grad is None:
+            optimizer.backward(output_obj)
+        else:
+            if "backward_tensor_keys" not in output_obj:
+                for k, grad in output_obj_grad.items():
+                    optimizer.backward_by_grad(output_obj[k], grad)
+            else:
+                for k, grad in output_obj_grad.items():
+                    output_obj[k].grad = grad
+                for k in output_obj["backward_tensor_keys"]:
+                    tensor_to_backward = output_obj[k]
+                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+
+        # Collect the grad of the input_obj.
+        input_obj_grad = None
+        if input_obj is not None:
+            input_obj_grad = {}
+            for k, v in input_obj.items():
+                if isinstance(v, torch.Tensor) and v.grad is not None:
+                    input_obj_grad[k] = v.grad
+        return input_obj_grad
+
+    def forward_backward_step(self,
+                              model_chunk: Module,
+                              optimizer: OptimizerWrapper,
+                              data_iter: Iterable,
+                              criterion: Callable[..., Any],
+                              return_loss: bool = False,
+                              return_outputs: bool = False) -> dict:
+        """Runs interleaved 1F1B schedule, with communication between pipeline stages.
+
+        Args:
+            model_chunk (List[Module]): Model chunk to be trained.
+            optimizer (OptimizerWrapper): Optimizer to be used.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            return_loss (bool, optional): Whether to return the loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return the model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys: 'loss' and 'outputs'.
+        """
+        forward_only = not torch.is_grad_enabled()
+
+        self.load_batch(data_iter)
+        num_model_chunks = len(model_chunk)
+
+        # num_warmup_microbatches is the number of steps during which not all stages are working
+        num_microbatches = self.num_microbatches * num_model_chunks
+        if forward_only:
+            num_warmup_microbatches = num_microbatches
+        else:
+            num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+            num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
+            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+
+        num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+
+        # Input, output tensors only need to be saved when doing backward passes
+        input_objs = None
+        output_objs = None
+
+        if not forward_only:
+            input_objs = [[] for _ in range(num_model_chunks)]
+            output_objs = [[] for _ in range(num_model_chunks)]
+
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+
+        if return_loss and self.stage_manager.is_last_stage():
+            accum_loss = torch.zeros(1, device=get_current_device())
+        else:
+            accum_loss = None
+
+        # for ranks except the first one, get into recv state
+        # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
+        input_obj = self.recv_forward(0)
+        input_objs[0].append(input_obj)
+        # Run warmup forward passes.
+        for i in range(num_warmup_microbatches):
+            model_chunk_id = self.get_model_chunk_id(i, forward=True)
+
+            # recv first on the first rank to avoid sending and receiving at the same time
+            if self.stage_manager.is_first_stage():
+                input_obj = self.recv_forward(model_chunk_id)
+                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+                self.send_forward(model_chunk_id, output_obj)
+                if not forward_only:
+                    input_objs[model_chunk_id].append(input_obj)
+                    output_objs[model_chunk_id].append(output_obj)
+            else:
+                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+                if not forward_only:
+                    output_objs[model_chunk_id].append(output_obj)
+                self.send_forward(model_chunk_id, output_obj)
+                if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
+                    break
+                else:
+                    model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
+
+                    input_obj = self.recv_forward(model_chunk_id)
+                    if not forward_only:
+                        input_objs[model_chunk_id].append(input_obj)
+
+        # Run 1F1B in steady state.
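+        # In the steady state, each iteration runs one forward pass for the next microbatch
+        # of its forward model chunk and one backward pass for the oldest outstanding
+        # microbatch of the corresponding backward model chunk, interleaving sends and recvs.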
+ for i in range(num_microbatches_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.send_forward(model_chunk_id, output_obj) + + if not last_iteration: + input_obj = self.recv_forward(model_chunk_id) + + else: + self.send_forward(model_chunk_id, output_obj) + # Add input_obj and output_obj to end of list. + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + + model_chunk_id = self.get_model_chunk_id(i, forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + # backward + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(i, forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_microbatches_remaining, num_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=False) + # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + output_obj_grad = self.recv_backward(model_chunk_id) + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) + + if outputs is not None: + outputs = merge_batch(outputs) + return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index ade3cf456fe3..f5e4929aa7c8 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -53,6 +53,62 @@ def load_micro_batch(self) -> Any: self.microbatch_offset += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For 1F1B. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.stage_manager.is_first_stage(): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For 1F1B. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. 
+ """ + if self.stage_manager.is_last_stage(): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For 1F1B. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank) + def forward_step(self, model: Module, input_obj: Optional[dict], @@ -171,11 +227,11 @@ def forward_backward_step(self, # Run warmup forward passes. for i in range(num_warmup_microbatches): - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not forward_only: input_objs.append(input_obj) @@ -185,7 +241,7 @@ def forward_backward_step(self, # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. if num_microbatches_remaining > 0: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() # Run 1F1B in steady state. for i in range(num_microbatches_remaining): @@ -193,15 +249,15 @@ def forward_backward_step(self, output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) if forward_only: - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not last_iteration: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() else: # TODO adjust here - self.comm.send_forward(output_obj) - output_obj_grad = self.comm.recv_backward() + self.send_forward(output_obj) + output_obj_grad = self.recv_backward() # Add input_obj and output_obj to end of list. input_objs.append(input_obj) @@ -216,8 +272,8 @@ def forward_backward_step(self, if last_iteration: input_obj = None else: - input_obj = self.comm.recv_forward() - self.comm.send_backward(input_obj_grad) + input_obj = self.recv_forward() + self.send_backward(input_obj_grad) # Run cooldown backward passes. if not forward_only: @@ -225,9 +281,9 @@ def forward_backward_step(self, input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - output_obj_grad = self.comm.recv_backward() + output_obj_grad = self.recv_backward() input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) - self.comm.send_backward(input_obj_grad) + self.send_backward(input_obj_grad) if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index fe228e2270dd..6ba7dc629958 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -17,28 +17,24 @@ class PipelineStageManager: Attributes: num_stages (int): Number of stages in the pipeline. stage (int): The current stage. - num_virtual_stages (int): Number of virtual stages in the pipeline. - virtual_stage (int): The current virtual stage. 
""" - def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis - self.num_virtual_stages: Optional[int] = None - self.virtual_stage: Optional[int] = None self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} # init prev and next coord coord = self.pg_mesh.coordinate() - if self.stage > 0: - prev_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] - self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) - if self.stage < self.num_stages - 1: - next_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] - self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) + # the prev rank of rank0 is the last rank + prev_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap') + # the next rank of the last rank is rank0 + next_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap') # init p2p process groups stages = list(range(self.num_stages)) @@ -48,32 +44,28 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - def is_first_stage(self, virtual: bool = False) -> bool: - """Is the current stage the first stage. + if is_virtual: + # add the process group of the first rank and the last rank + # only used in interleaved pipeline for now + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) + if self.stage in [stages[0], stages[-1]]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + def is_first_stage(self) -> bool: + """Is the current stage the first stage. Returns: bool: Whether the current stage is the first stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == 0 return self.stage == 0 - def is_last_stage(self, virtual: bool = False) -> bool: + def is_last_stage(self) -> bool: """Is the current stage the last stage. - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. - Returns: bool: Whether the current stage is the last stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == self.num_virtual_stages - 1 return self.stage == self.num_stages - 1 @property @@ -108,7 +100,6 @@ def get_prev_rank(self) -> int: Returns: int: Rank of the previous stage. """ - assert not self.is_first_stage(), "Cannot get previous rank in the first stage." return self.prev_rank def get_next_rank(self) -> int: @@ -117,39 +108,8 @@ def get_next_rank(self) -> int: Returns: int: Rank of the next stage. """ - assert not self.is_last_stage(), "Cannot get next rank in the last stage." return self.next_rank - def set_num_virtual_stages(self, num_virtual_stages: int) -> None: - """Set the number of virtual stages. 
- - Args: - num_virtual_stages (int): Number of virtual stages. - """ - self.num_virtual_stages = num_virtual_stages - - def set_virtual_stage(self, virtual_stage: int) -> None: - """Set the virtual stage. - - Args: - virtual_stage (int): Virtual stage. - """ - self.virtual_stage = virtual_stage - - @contextmanager - def switch_virtual_stage(self, virtual_stage: int) -> None: - """A context manager to switch virtual stage. - - Args: - virtual_stage (int): Target virtual stage. - """ - old_stage = self.virtual_stage - try: - self.set_virtual_stage(virtual_stage) - yield - finally: - self.set_virtual_stage(old_stage) - def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: """Get the p2p process group between two ranks. The order of the two ranks does not matter. diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py new file mode 100644 index 000000000000..2ac31c8ca0d1 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -0,0 +1,161 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 8) + self.linear3 = nn.Linear(8, 8) + self.linear4 = nn.Linear(8, 8) + self.linear5 = nn.Linear(8, 8) + self.linear6 = nn.Linear(8, 8) + self.linear7 = nn.Linear(8, 8) + self.linear8 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + x = self.linear7(x) + x = self.linear8(x) + return x + + +def pp_linear_fwd(forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None): + + if stage_mgr.is_first_stage() and model_chunk_id == 0: + return {'input_obj': forward(data)} + elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + return forward(input_obj) + else: + return {'input_obj': forward(input_obj)} + + +@parameterize("num_micro_batches", [4, 8, 12]) +def examine_pp(num_micro_batches): + """ + This test is to examine the correctness of interleaved 1F1B, compared with torch. + Be aware it contains some hardcodes. 
+ """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = num_micro_batches + BATCH_SIZE = num_micro_batches + NUM_CHUNKS = 2 + + # create model + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) + schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager) + + sharded_model = torch.nn.ModuleList() + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, + stage_mgr=stage_manager, + num_chunks=NUM_CHUNKS, + model_chunk_id=len(sharded_model)), sub_model._forward) + sharded_model.append(sub_model.cuda()) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if local_rank == 0: + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, + pp_optimizer, + iter(input_list), + criterion, + return_loss=True, + return_outputs=True) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret['loss']) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + else: + assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + else: + assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_pp() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58f74..6e0cd1998c11 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -49,15 +49,6 @@ def check_stage_manager(): next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] assert stage_manager.get_next_rank() == next_rank - # check virtual stage - stage_manager.set_num_virtual_stages(PP_SIZE * 2) - assert stage_manager.num_virtual_stages == PP_SIZE * 2 - stage_manager.set_virtual_stage(stage_manager.stage * 2) - assert stage_manager.virtual_stage == stage_manager.stage * 2 - with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): - assert 
stage_manager.virtual_stage == stage_manager.stage * 2 + 1 - assert stage_manager.virtual_stage == stage_manager.stage * 2 - # check p2p groups for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): if rank in [prev, cur]: From 7c8be770810835544e2652c6d053e77db83a0949 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:21:53 +0800 Subject: [PATCH 05/33] [shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460) * support gpt2 seq parallel with pp/dp/tp * fix a bug when waiting for stream done * delete unused gpt2_seq file --- .../booster/plugin/hybrid_parallel_plugin.py | 4 + colossalai/shardformer/layer/_operation.py | 2 + colossalai/shardformer/modeling/gpt2.py | 256 +++++++++++++++++- colossalai/shardformer/modeling/gpt2_seq.py | 222 --------------- colossalai/shardformer/policies/gpt2.py | 14 +- .../test_model/test_shard_gpt2.py | 10 +- 6 files changed, 268 insertions(+), 240 deletions(-) delete mode 100644 colossalai/shardformer/modeling/gpt2_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 00c714fe4612..155f72dc6db2 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -235,6 +235,10 @@ def __init__(self, assert dist.get_world_size() % ( tp_size * pp_size ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + + if enable_sequence_parallelism: + assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + # TODO(ver217): support zero assert zero_stage == 0, 'zero is not support yet' self.tp_size = tp_size diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 13e563123d28..fc13aca79969 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -239,6 +239,7 @@ def backward(ctx, grad_output): output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() torch.cuda.current_stream().wait_stream(calculate_stream) + gather_handle.wait() reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) with torch.cuda.stream(calculate_stream): @@ -249,6 +250,7 @@ def backward(ctx, grad_output): grad_weight = grad_output.t().matmul(input_parallel) torch.cuda.current_stream().wait_stream(calculate_stream) + reducescatter_handle.wait() return output, grad_weight, grad_bias, None, None, None, None diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 47835d5d5468..722f0f52334b 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,6 +21,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig class GPT2PipelineForwards: @@ -47,7 +49,8 @@ def gpt2_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, 
Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -159,6 +162,13 @@ def gpt2_model_forward( all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): @@ -212,6 +222,12 @@ def custom_forward(*inputs): if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -257,7 +273,8 @@ def gpt2_lmhead_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -285,7 +302,8 @@ def gpt2_lmhead_model_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -335,7 +353,8 @@ def gpt2_double_heads_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. 
Selected in the range `[0, input_ids.size(-1) - @@ -367,7 +386,8 @@ def gpt2_double_heads_model_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -421,7 +441,8 @@ def gpt2_for_question_answering_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -449,7 +470,8 @@ def gpt2_for_question_answering_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -508,7 +530,8 @@ def gpt2_for_token_classification_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -534,7 +557,8 @@ def gpt2_for_token_classification_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -578,7 +602,8 @@ def gpt2_for_sequence_classification_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -613,7 +638,8 @@ def gpt2_for_sequence_classification_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -696,7 +722,6 @@ def forward( output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: _, tgt_len, _ = hidden_states.size() - assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." 
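Each pipeline forward above now takes a trailing `shard_config` argument, and the policies below bind it (together with `stage_manager` and `stage_index`) with `functools.partial` before installing the function as the module's `forward`, so the Hugging Face call sites stay unchanged. Here is a toy version of that replacement pattern, with invented module and config names; it only mirrors the idea and is not the ColossalAI mechanism verbatim.

# Toy sketch of installing a partial-bound replacement forward (illustrative names only).
from functools import partial
from types import MethodType

import torch
import torch.nn as nn


class ToyShardConfig:
    enable_sequence_parallelism = True


def new_forward(self, x, *, shard_config=None):
    # extra arguments arrive pre-bound; callers still just call module(x)
    scale = 0.5 if shard_config is not None and shard_config.enable_sequence_parallelism else 1.0
    return self.linear(x) * scale


class ToyModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


m = ToyModule()
m.forward = MethodType(partial(new_forward, shard_config=ToyShardConfig()), m)
print(m(torch.randn(2, 4)).shape)    # torch.Size([2, 4])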
if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -753,3 +778,210 @@ def forward( return outputs return forward + + +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
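As a quick numeric check of the comment above (an aside, not part of the forward): turning a 0/1 padding mask into an additive bias of 0 or `finfo.min` and adding it to the raw scores makes the masked positions vanish after softmax.

# Tiny demo of the additive attention-mask trick described in the comment above.
import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])                  # 1 = attend, 0 = masked
bias = (1.0 - mask) * torch.finfo(torch.float32).min    # [[0, 0, -3.4e38]]
scores = torch.tensor([[2.0, 1.0, 3.0]])
print(torch.softmax(scores + bias, dim=-1))
# tensor([[0.7311, 0.2689, 0.0000]]): the masked position contributes nothing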
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + 
output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py deleted file mode 100644 index a6da96e7bf73..000000000000 --- a/colossalai/shardformer/modeling/gpt2_seq.py +++ /dev/null @@ -1,222 +0,0 @@ -# this code is modified from transformers.models.gpt2.modeling_gpt2 -# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670 - -from typing import Optional, Tuple, Union - -import torch -import torch.distributed as dist -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.utils import logging - -from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward -from colossalai.shardformer.shard import ShardConfig - -logger = logging.get_logger(__name__) - - -# TODO: put all contents in `gpt2.py` and make it compatible with pipeline -def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = 
return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - 
output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 276d95660c4d..d34c0ae9fe64 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,8 +6,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward -from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -50,8 +49,6 @@ def module_policy(self): target_module=col_nn.DropoutForParallelInput, ), ]) - if self.shard_config.enable_sequence_parallelism: - policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -126,6 +123,7 @@ def module_policy(self): }) if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) @@ -169,7 +167,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } 
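The sequence-parallel forward wired in above rests on two symmetric primitives: `split_forward_gather_backward` slices activations along the sequence dimension in the forward pass and all-gathers gradients in the backward pass, while `gather_forward_split_backward` does the reverse at the model output. A much-simplified sketch of the first primitive follows; it assumes an already initialized process group and a sequence length divisible by the group size, and it is an illustration of the idea rather than the library implementation.

# Simplified sketch of split-forward / gather-backward along the sequence dimension.
# Assumes torch.distributed is already initialized; not the ColossalAI implementation.
import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        # keep only this rank's slice of the sequence dimension
        return torch.chunk(x, world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        pieces = [torch.empty_like(grad_output) for _ in range(world_size)]
        # collect every rank's gradient slice and stitch the full sequence back together
        dist.all_gather(pieces, grad_output.contiguous(), group=ctx.group)
        return torch.cat(pieces, dim=ctx.dim), None, None


def split_seq(hidden_states, group):
    # [batch, seq_len, hidden] -> [batch, seq_len / world_size, hidden] on each rank
    return _SplitForwardGatherBackward.apply(hidden_states, 1, group)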
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 97295f72f4e1..0e29f1dd935a 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -105,10 +105,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_all_optimization': False, + 'enable_all_optimization': True, 'use_lazy_init': True, 'enable_sequence_parallelism': True, 'precision': 'fp32', From 0ecd71e041e517808097f09eafc97811f93d4235 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 18 Aug 2023 15:34:18 +0800 Subject: [PATCH 06/33] [shardformer] bloom support sequence parallel (#4465) [shardformer] bloom support sequence parallel --- colossalai/shardformer/modeling/bloom.py | 184 ++++++++++++++++++- colossalai/shardformer/policies/bloom.py | 24 ++- colossalai/shardformer/shard/shard_config.py | 1 + 3 files changed, 201 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 12276635ecfa..66f24dc6088b 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -23,6 +23,10 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -111,6 +115,7 @@ def bloom_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']: @@ -205,6 +210,13 @@ def bloom_model_forward( past_key_values_length=past_key_values_length, ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx): @@ -248,6 +260,12 @@ def custom_forward(*inputs): all_self_attentions = all_self_attentions + \ (outputs[2 if use_cache else 1],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if stage_manager.is_last_stage(): # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -287,6 +305,7 @@ def 
bloom_for_causal_lm_forward(self: BloomForCausalLM, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -327,7 +346,8 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -380,6 +400,7 @@ def bloom_for_sequence_classification_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ): r""" @@ -424,6 +445,7 @@ def bloom_for_sequence_classification_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -503,6 +525,7 @@ def bloom_for_token_classification_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ): r""" @@ -547,6 +570,7 @@ def bloom_for_token_classification_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -597,6 +621,7 @@ def bloom_for_question_answering_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -632,6 +657,7 @@ def bloom_for_question_answering_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -700,8 +726,7 @@ def forward( fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _ = hidden_states.size() - assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." 
+ batch_size, tgt_len, _ = query_layer.size() _, kv_length, _, _ = key_layer.size() @@ -896,3 +921,156 @@ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: return self.bloom_gelu_forward(x, bias) return forward + + +def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): + + from transformers import BloomModel + + def forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index b35764db3870..2727272d0867 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -12,6 +12,7 @@ BloomPipelineForwards, build_bloom_alibi_tensor_fn, get_bloom_flash_attention_forward, + get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_mlp_forward, @@ -43,6 +44,7 @@ def module_policy(self): policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": 
self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -53,11 +55,11 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", target_module=col_nn.DropoutForParallelInput, @@ -65,11 +67,11 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - ), + kwargs={'seq_parallel': use_sequence_parallel}), ]) policy[BloomModel] = ModulePolicyDescription( @@ -116,6 +118,12 @@ def module_policy(self): policy=policy, target_key=BloomBlock) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BloomModel) + if self.shard_config.enable_flash_attention: policy[BloomAttention] = ModulePolicyDescription(method_replacement={ 'forward': get_bloom_flash_attention_forward(), @@ -154,7 +162,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a36e878c623f..900f8475c71b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -58,3 +58,4 @@ def _turn_on_all_optimization(self): self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True + self.enable_sequence_parallelism = True From a27e0bb494c1260678df0587419913340fda0c1d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 18 Aug 2023 18:04:55 +0800 Subject: [PATCH 07/33] [shardformer] bert support sequence parallel. 
(#4455) * [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel --- colossalai/shardformer/layer/_operation.py | 6 +- colossalai/shardformer/modeling/bert.py | 246 ++++++++++++++++++--- colossalai/shardformer/policies/bert.py | 24 +- 3 files changed, 234 insertions(+), 42 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index fc13aca79969..f1f48273ccd1 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -217,9 +217,7 @@ def backward(ctx, grad_output): # do all gather in default stream input_ = input_.contiguous() world_size = dist.get_world_size(process_group) - rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) # calculate gradient in calculate_stream @@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None): # all gather input_ = input_.contiguous() - rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=process_group) # concat diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 5bd1c531cc68..d88661953a29 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -29,6 +29,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward class BertPipelineForwards: @@ -56,6 +58,7 @@ def bert_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # TODO(jianghai): add explaination of the output here. 
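# Note on the `_gather` change above (a side sketch, not part of the patch): the manual
# `tensor_list[rank] = input_` assignment is dropped presumably because
# `torch.distributed.all_gather` already fills every entry of `tensor_list`, including the
# local rank's slot, so the assignment was redundant. The two primitives imported into
# modeling/bert.py, `split_forward_gather_backward` and `gather_forward_split_backward`,
# pair a sequence-dimension split in one direction with an all-gather in the other. Below
# is a minimal standalone illustration of the split-forward / gather-backward half; the
# class name is invented for this note and the sequence length is assumed to divide
# evenly by the tensor-parallel world size.
import torch
import torch.distributed as dist


class _SplitForwardGatherBackwardSketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # each rank keeps only its own slice of the sequence dimension
        return torch.chunk(input_, world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        grad_output = grad_output.contiguous()
        # collect every rank's gradient slice and rebuild the full-sequence gradient
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output, group=ctx.process_group)
        return torch.cat(chunks, dim=ctx.dim), None, None

# Judging from the naming and its usages in this patch, `gather_forward_split_backward`
# is the mirror image: all-gather along `dim` in forward, keep only the local gradient
# chunk in backward.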
r""" @@ -177,6 +180,14 @@ def bert_model_forward( start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -223,11 +234,17 @@ def custom_forward(*inputs): all_cross_attentions = all_cross_attentions + \ (layer_outputs[2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None + sequence_output = hidden_states if hidden_states is not None else None if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -268,6 +285,7 @@ def bert_for_pretraining_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) @@ -294,6 +312,7 @@ def bert_for_pretraining_forward( stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -350,6 +369,7 @@ def bert_lm_head_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -404,7 +424,8 @@ def bert_lm_head_model_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -457,6 +478,7 @@ def bert_for_masked_lm_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -491,6 +513,7 @@ def bert_for_masked_lm_forward( hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): @@ -532,6 +555,7 @@ def bert_for_next_sentence_prediction_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **kwargs, ): # -> Union[Tuple[torch.Tensor], 
NextSentencePredictorOutput]: @@ -594,7 +618,8 @@ def bert_for_next_sentence_prediction_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -636,6 +661,7 @@ def bert_for_sequence_classification_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -666,7 +692,8 @@ def bert_for_sequence_classification_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -726,6 +753,7 @@ def bert_for_token_classification_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -742,21 +770,20 @@ def bert_for_token_classification_forward( logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -799,6 +826,7 @@ def bert_for_multiple_choice_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -843,6 +871,7 @@ def bert_for_multiple_choice_forward( hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -886,6 +915,7 @@ def bert_for_question_answering_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # NOTE: the arg start_position and end_position are used only for the last stage r""" @@ -909,21 +939,20 @@ def bert_for_question_answering_forward( logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - 
attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -1101,3 +1130,150 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T return hidden_states return forward + + +def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + embedding_output = split_forward_gather_backward(embedding_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + sequence_output = gather_forward_split_backward(sequence_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ace9ada3904f..fe091c658682 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -10,6 +10,7 @@ from .._utils import getattr_, setattr_ from ..modeling.bert import ( BertPipelineForwards, + bert_sequence_parallel_forward_fn, get_bert_flash_attention_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, @@ -47,13 +48,14 @@ def module_policy(self): from transformers.models.bert.modeling_bert import ( BertEmbeddings, BertLayer, + BertModel, BertOutput, BertSelfAttention, BertSelfOutput, ) policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ "attention.self.all_head_size": @@ -69,14 +71,17 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -85,6 +90,7 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -93,10 +99,12 @@ def module_policy(self): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -115,6 +123,12 @@ def module_policy(self): ) ]) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BertModel) + # optimization configuration if self.shard_config.enable_fused_normalization: # Handle bert layer @@ -205,7 +219,13 @@ def set_pipeline_forward(self, 
model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) From 8739aa7fa01a9c04743dea813e2cc210e30dd77f Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 18 Aug 2023 21:29:25 +0800 Subject: [PATCH 08/33] [shardformer] Pipeline/whisper (#4456) * add some base tests and policies * finish whisper base model * add conditional generation * finish basic tests * whisper * finish whisper * finish whisper * del useless whisper test * fix * add argmin to replace * finish revision --- colossalai/shardformer/modeling/whisper.py | 715 +++++++++++++++++- colossalai/shardformer/policies/blip2.py | 9 - colossalai/shardformer/policies/t5.py | 9 +- colossalai/shardformer/policies/whisper.py | 243 +++++- .../test_t5_pipeline_utils.py | 39 + .../test_whisper_pipeline_utils.py | 44 ++ .../test_model/test_shard_llama.py | 2 + .../test_model/test_shard_whisper.py | 152 +++- 8 files changed, 1158 insertions(+), 55 deletions(-) create mode 100644 tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py create mode 100644 tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 0a16c6f788da..62f8f7b4763e 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -1,7 +1,26 @@ -from typing import Optional, Tuple +import logging +import random +from typing import Dict, List, Optional, Set, Tuple, Union import torch from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, +) +from transformers.models.whisper.modeling_whisper import ( + WhisperEncoder, + WhisperForAudioClassification, + WhisperForConditionalGeneration, + WhisperModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): @@ -247,3 +266,697 @@ def forward( return outputs return forward + + +class WhisperPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + ''' + + @staticmethod + def whisper_encoder_forward( + self: WhisperEncoder, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. 
Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + + stage = stage_manager.stage + at_first_stage = (stage == 0) + at_last_stage = (stage == decoder_starting_stage - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process inputs if at the first stage of encoder. + if at_first_stage: + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
+ + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + encoder_layer = self.layers[idx] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + else: + return {'hidden_states': hidden_states, 'head_mask': head_mask} + + @staticmethod + def whisper_decoder_forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + logger = logging.get_logger(__name__) + stage = stage_manager.stage + at_first_stage = (stage == decoder_starting_stage) + at_last_stage = (stage == stage_manager.num_stages - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if at_first_stage: + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." 
+ ) + use_cache = False + + else: + + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + input_shape = hidden_states.size()[:-1] + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states, + past_key_values_length) + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + else: + return { + 'head_mask': head_mask, + 'cross_attn_head_mask': cross_attn_head_mask, + 'hidden_states': hidden_states, + } + + @staticmethod + def whisper_model_forward( + self: WhisperModel, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Returns: + + Example: + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + in_decoder = stage_manager.stage >= decoder_starting_stage + if not in_decoder: + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {'encoder_hidden_states': encoder_outputs[0]} + else: + return encoder_outputs + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + at_last_decoder_stage = stage_manager.is_last_stage() + 
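# A small worked example (hypothetical sizes, not part of the patch) of the stage routing
# used by whisper_encoder_forward / whisper_decoder_forward / whisper_model_forward above:
# with 4 pipeline stages and the decoder starting at stage 2, stages 0-1 run encoder
# layers and stages 2-3 run decoder layers.
num_stages, decoder_starting_stage = 4, 2    # assumed layout for illustration

for stage in range(num_stages):
    in_decoder = stage >= decoder_starting_stage
    if not in_decoder:
        note = "runs encoder layers"
        if stage == 0:
            note += "; embeds input_features via conv1/conv2 + positional embeddings"
        if stage == decoder_starting_stage - 1:
            note += "; applies encoder layer_norm and emits encoder_hidden_states"
    else:
        note = "runs decoder layers"
        if stage == decoder_starting_stage:
            note += "; embeds decoder_input_ids"
        if stage == num_stages - 1:
            note += "; applies decoder layer_norm and returns the final output"
    print(f"stage {stage}: {note}")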
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + # Directly return outputs of overloaded Whisper forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def whisper_for_conditional_generation_forward( + self: WhisperForConditionalGeneration, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + in_decoder = stage_manager.stage >= decoder_starting_stage + at_last_decoder_stage = stage_manager.is_last_stage() + outputs = WhisperPipelineForwards.whisper_model_forward(self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + if not in_decoder: + return outputs + + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + outputs['encoder_hidden_states'] = encoder_hidden_states + return outputs + + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def whisper_for_audio_classification_forward( + self: WhisperForAudioClassification, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: 
Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. + Please refer to original code of transformers for more details. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # audio_classification only holds encoder + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + if not stage_manager.is_last_stage(): + return encoder_outputs + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 50356302e93e..3610e2c4109b 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -304,15 +304,6 @@ def module_policy(self): return policy def postprocess(self): - binding_map = { - 'language_model.model.decoder.embed_tokens': 'language_model.lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - return self.model diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 2ef52c214c6b..651883d35b87 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,6 +1,7 @@ from functools import partial from typing import Callable, Dict, List, Optional, Tuple +import numpy as np from torch import Tensor, nn from colossalai.shardformer.layer import ( @@ -228,13 +229,7 @@ def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int, def objective(num_encoder_stages): return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) - num_encoder_stages = 0 - optimal_diff = 2**31 - 1 - for i in range(1, num_stages): - attempt = objective(i) - if attempt < optimal_diff: - num_encoder_stages = i - optimal_diff = attempt 
+ num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 2ac7a49fd27b..a33f929f1e48 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,10 +1,16 @@ +from functools import partial +from typing import Callable, Dict, List, Tuple + +import numpy as np import torch.nn as nn +from torch import Tensor import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.whisper import ( + WhisperPipelineForwards, get_jit_fused_whisper_decoder_layer_forward, get_jit_fused_whisper_encoder_layer_forward, get_whisper_flash_attention_forward, @@ -12,7 +18,8 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' + 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', + 'WhisperForAudioClassificationPolicy' ] @@ -223,6 +230,146 @@ def add_lm_head_policy(self, base_policy): def postprocess(self): return self.model + @staticmethod + def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int, + num_stages: int) -> Tuple[List[int], int]: + """ + Distribute whisper layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for whisper must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_whisper_stage_index(layers_per_stage: List[int], stage: int, + decoder_starting_stage: int) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. 
+ Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + # whisper for audio classification holds encoder only + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. 
+ if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + # WhisperModel class WhisperModelPolicy(WhisperPolicy): @@ -230,6 +377,24 @@ class WhisperModelPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers import WhisperModel + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperModel, + new_forward=WhisperPipelineForwards.whisper_model_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in whisper model" + return [] + # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): @@ -238,20 +403,82 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + from transformers import WhisperForConditionalGeneration + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy) + return policy def postprocess(self): - binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) return self.model + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) + return held_layers + 
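# --- Editor's illustrative sketch (not part of the patch) ---------------------------
# A minimal, self-contained sketch of the encoder/decoder pipeline split that
# distribute_whisper_layers / get_whisper_stage_index perform above. It assumes a
# simple even-split helper in place of Policy.distribute_layers (the real helper may
# balance remainders differently); names below are illustrative only.
import numpy as np
from typing import List, Tuple


def even_distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # assumption: spread layers as evenly as possible, giving remainders to earlier stages
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if stage < rem else 0) for stage in range(num_stages)]


def split_encoder_decoder(num_encoder_layers: int, num_decoder_layers: int,
                          num_stages: int) -> Tuple[List[int], int]:
    # encoder-only model: every stage belongs to the encoder, decoder "starts" after the last stage
    if num_decoder_layers == 0:
        return even_distribute_layers(num_encoder_layers, num_stages), num_stages

    # pick the encoder stage count that best balances layers-per-stage between encoder and decoder,
    # i.e. argmin over |enc_layers / enc_stages - dec_layers / dec_stages| with at least one stage each
    def imbalance(encoder_stages: int) -> float:
        decoder_stages = num_stages - encoder_stages
        return abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)

    encoder_stages = int(np.argmin([imbalance(i) for i in range(1, num_stages)])) + 1
    decoder_stages = num_stages - encoder_stages
    distribution = (even_distribute_layers(num_encoder_layers, encoder_stages) +
                    even_distribute_layers(num_decoder_layers, decoder_stages))
    return distribution, encoder_stages


# e.g. 4 encoder + 8 decoder layers on 4 stages gives ([4, 3, 3, 2], 1) under this
# sketch's even split: one encoder stage, decoder starting at stage 1.
# -----------------------------------------------------------------------------------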
+ def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + model = module.model + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers, + stage_manager.num_stages) + shared_params = [] + shared_embedding = {} + if id(module.proj_out) == id(model.decoder.embed_tokens): + shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens + shared_embedding[stage_manager.num_stages - 1] = module.proj_out + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) + return shared_params + return [] + # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers import WhisperForAudioClassification + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py new file mode 100644 index 000000000000..0cbb852b97a0 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -0,0 +1,39 @@ +from colossalai.shardformer.policies.t5 import T5BasePolicy + + +def test_t5_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_t5_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == 
predicted_start + assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py new file mode 100644 index 000000000000..395519e97898 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -0,0 +1,44 @@ +from colossalai.shardformer.policies.whisper import WhisperPolicy + + +def test_whisper_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_whisper_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end + + +if __name__ == '__main__': + test_whisper_pipeline_distribution() + test_whisper_pipeline_layers() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a433567b3702..ec5578a765c5 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -143,6 +144,7 @@ def run_llama_test(test_config): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 9b38ae07b1d6..90e007e34de8 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -3,6 +3,8 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -11,55 +13,145 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, 
check_grad, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5) - - # do backward - org_loss.backward() - shard_loss.backward() - - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'WhisperModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model - sharded_whisper = sharded_model.model + sharded_whisper = sharded_model.unwrap().model else: whisper = org_model - sharded_whisper = sharded_model + sharded_whisper = sharded_model.unwrap() # check grad if org_model.__class__.__name__ == 'WhisperForAudioClassification': col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] else: - col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj'] - row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj'] - check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + col_layer_for_check = [ + 'encoder.layers[0].self_attn.q_proj', + # 'decoder.layers[0].self_attn.q_proj' + ] + row_layer_for_check = [ + 'encoder.layers[0].self_attn.out_proj', + #'decoder.layers[0].self_attn.out_proj' + ] + + # check weights and gradients + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) + check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + 
check_weight(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + check_weight(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): + +# TODO(jianghai) fix fp16 +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', + 'initial_scale': 1, +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', +}]) +def run_whisper_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, - enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification': + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_whisper(): - spawn(check_whisper, 2) + spawn(check_whisper, 4) if __name__ == "__main__": From 1c7df566e23d5b94512f0777e0475df0a0ae1072 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 21 Aug 2023 12:04:52 +0800 Subject: [PATCH 09/33] [shardformer] support tp+zero for shardformer (#4472) * support tp+zero/input type cast for hybridplugin * add tp+zero tests * fix bucket arguments --- .../booster/plugin/hybrid_parallel_plugin.py | 89 +++++++++++++------ .../test_model/test_shard_bert.py | 12 ++- .../test_model/test_shard_bloom.py | 10 ++- .../test_model/test_shard_chatglm.py | 10 ++- .../test_model/test_shard_gpt2.py | 10 ++- .../test_model/test_shard_llama.py | 10 ++- .../test_model/test_shard_opt.py | 10 ++- .../test_model/test_shard_t5.py | 12 ++- .../test_model/test_shard_vit.py | 10 ++- 9 files changed, 136 insertions(+), 37 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 155f72dc6db2..016323ae7821 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,5 +1,6 @@ import random from contextlib import nullcontext +from functools import partial from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import numpy as np @@ -10,6 +11,7 @@ from torch.nn.parallel import 
DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -27,32 +29,49 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, ddp_config: dict) -> None: + self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group + shardformer = ShardFormer(shard_config) module, self.shared_params = shardformer.optimize(module) - # TODO(ver217): add input type cast + + # setting process groups for shared parameters self.shared_param_process_groups = [] for shared_param in self.shared_params: if len(shared_param) > 0: self.shared_param_process_groups.append( self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) + + # setting mixed_precision + self.mixed_precision = None if precision == 'fp16': - module = module.half().cuda() + self.mixed_precision = torch.float16 elif precision == 'bf16': - module = module.to(dtype=torch.bfloat16).cuda() - else: - module = module.cuda() # train without AMP + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.cuda() - if use_ddp: + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + # setting ddp configs + if use_ddp: # convert model to sync bn module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) - # wrap the model with PyTorch DDP module = DDP(module, process_group=dp_group, **ddp_config) @@ -78,6 +97,12 @@ def sync_grads(self): dist.all_reduce(p.grad, group=self.dp_group) p.grad.div_(self.dp_group.size()) + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + return super().forward(*args, **kwargs) + def unwrap(self): module = super().unwrap() if isinstance(module, DDP): @@ -180,7 +205,6 @@ class HybridParallelPlugin(PipelinePluginBase): Defaults to 'fp16'. zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. When set to 0, ZeRO will not be used. Defaults to 0. - cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. Currently all the optimization methods include fused normalization, flash attention and JIT. Defaults to False. @@ -196,12 +220,16 @@ class HybridParallelPlugin(PipelinePluginBase): hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. - broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True. - bucket_cap_mb (int, optional): The bucket size in MB. 
Only for usage of DDP. Defaults to 25. - find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False. - check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False. - gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False. - static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. """ def __init__(self, @@ -209,7 +237,6 @@ def __init__(self, pp_size: int, precision: str = 'fp16', zero_stage: int = 0, - cpu_offload: bool = False, enable_all_optimization: bool = False, enable_fused_normalization: bool = False, enable_flash_attention: bool = False, @@ -224,12 +251,16 @@ def __init__(self, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0, - broadcast_buffers=True, - bucket_cap_mb=25, - find_unused_parameters=False, - check_reduction=False, - gradient_as_bucket_view=False, - static_graph=False) -> None: + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True) -> None: super().__init__() assert dist.get_world_size() % ( @@ -239,8 +270,6 @@ def __init__(self, if enable_sequence_parallelism: assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' - # TODO(ver217): support zero - assert zero_stage == 0, 'zero is not support yet' self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -282,11 +311,18 @@ def __init__(self, ) self.ddp_config = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, + bucket_cap_mb=ddp_bucket_cap_mb, find_unused_parameters=find_unused_parameters, check_reduction=check_reduction, gradient_as_bucket_view=gradient_as_bucket_view, static_graph=static_graph) + + self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2)) + self.max_norm = max_norm 
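# --- Editor's illustrative sketch (not part of the patch) ---------------------------
# A minimal sketch of the mixed-precision input cast that HybridParallelModule.forward
# performs above: every floating-point tensor in the (possibly nested) forward args and
# kwargs is converted to the configured half-precision dtype via tree_map, while integer
# tensors and non-tensor values pass through untouched. The wrapper class and helper
# names here are illustrative, not the plugin's public API.
from functools import partial

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map


def _cast_floating_point(x, dtype: torch.dtype = torch.float16):
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


class HalfPrecisionInputWrapper(nn.Module):

    def __init__(self, module: nn.Module, dtype: torch.dtype = torch.float16):
        super().__init__()
        # cast parameters once, then cast inputs on every forward call
        self.module = module.to(dtype)
        self.convert_fn = partial(_cast_floating_point, dtype=dtype)

    def forward(self, *args, **kwargs):
        args = tree_map(self.convert_fn, args)
        kwargs = tree_map(self.convert_fn, kwargs)
        return self.module(*args, **kwargs)


# usage: integer token ids would stay int64, while float inputs become fp16
# model = HalfPrecisionInputWrapper(nn.Linear(8, 8))
# out = model(torch.randn(2, 8))    # input tensor is cast to fp16 before the linear layer
# -----------------------------------------------------------------------------------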
@property @@ -337,15 +373,16 @@ def configure( model, use_pipeline=self.enable_pipeline_parallelism) else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, - partition_grad=(self.zero_stage == 2), - cpu_offload=self.cpu_offload, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, clip_grad_norm=self.max_norm, + **self.zero_config, **self.amp_config) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 49de9cc0311c..c967017041af 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -56,9 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): - #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3) - #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3) + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) @@ -101,6 +99,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index af014a8585b5..bd87be8b7b65 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] col_layer_for_check = ['h[0].self_attention.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-5 else: @@ -101,6 +101,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bloom_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 210f775b540d..64732e06bbc4 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -55,7 
+55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] col_layer_for_check = ['encoder.layers[0].self_attention.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-3 else: @@ -125,6 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_chatglm_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 0e29f1dd935a..c776a80d8b65 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -56,7 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] # check grad - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 else: @@ -120,6 +120,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'use_lazy_init': True, 'enable_sequence_parallelism': True, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_gpt2_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ec5578a765c5..7140c4666861 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -60,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] col_layer_for_check = ['layers[0].self_attn.o_proj'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-4 else: @@ -135,6 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_llama_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 2fb14903b6a9..e6faafdaea4a 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -58,7 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 
'decoder.embed_tokens' col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-3 else: @@ -127,6 +127,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_opt_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 234ce812a08c..599f5a80d8ba 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -55,12 +55,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] - # check weights and gradients + # check grad if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) # check weights after optimizer.step() @@ -110,6 +110,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_t5_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index b9d303841215..b27add24cd09 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] col_layer_for_check = ['encoder.layer[0].attention.output.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: @@ -124,6 +124,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_vit_test(test_config): From 5545114fd84c8aa39b18aa0ad8816ddbc6dab360 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 22 Aug 2023 14:13:31 +0800 Subject: [PATCH 10/33] rename chatglm to chatglm2 (#4484) --- colossalai/shardformer/modeling/{chatglm.py => chatglm2.py} | 0 colossalai/shardformer/policies/auto_policy.py | 4 ++-- colossalai/shardformer/policies/{chatglm.py => chatglm2.py} | 4 ++-- 
tests/kit/model_zoo/transformers/__init__.py | 2 +- tests/kit/model_zoo/transformers/{chatglm.py => chatglm2.py} | 0 .../{test_shard_chatglm.py => test_shard_chatglm2.py} | 0 6 files changed, 5 insertions(+), 5 deletions(-) rename colossalai/shardformer/modeling/{chatglm.py => chatglm2.py} (100%) rename colossalai/shardformer/policies/{chatglm.py => chatglm2.py} (98%) rename tests/kit/model_zoo/transformers/{chatglm.py => chatglm2.py} (100%) rename tests/test_shardformer/test_model/{test_shard_chatglm.py => test_shard_chatglm2.py} (100%) diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm2.py similarity index 100% rename from colossalai/shardformer/modeling/chatglm.py rename to colossalai/shardformer/modeling/chatglm2.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index eec339c02872..2fe49f0d5afe 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -125,9 +125,9 @@ class PolicyLocation: # ChatGLM "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"), "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": - PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm2.py similarity index 98% rename from colossalai/shardformer/policies/chatglm.py rename to colossalai/shardformer/policies/chatglm2.py index e6b458936637..a15aa856dcb8 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -7,7 +7,7 @@ import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -15,7 +15,7 @@ GLMBlock, ) -from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 823ca032fc30..2a492361b13b 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,7 +2,7 @@ from .bert import * from .blip2 import * from .bloom import * -from .chatglm import * +from .chatglm2 import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm2.py similarity index 100% rename from tests/kit/model_zoo/transformers/chatglm.py rename to tests/kit/model_zoo/transformers/chatglm2.py diff --git 
a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py similarity index 100% rename from tests/test_shardformer/test_model/test_shard_chatglm.py rename to tests/test_shardformer/test_model/test_shard_chatglm2.py From 351351a36eb9d11e5bdb3610b0d3705055d90e7d Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 22 Aug 2023 17:35:35 +0800 Subject: [PATCH 11/33] [shardformer/sequence parallel] not support opt of seq-parallel, add warning and fix a bug in gpt2 pp (#4488) --- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/policies/opt.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 722f0f52334b..8ed367b25349 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -148,7 +148,7 @@ def gpt2_model_forward( if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) + hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ba6036bd0658..58663553b922 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -39,6 +40,9 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ From 59e252ecdbab0fe56fd3bacc9833188fe5285d02 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 22 Aug 2023 23:59:31 +0800 Subject: [PATCH 12/33] [shardformer] chatglm support sequence parallel (#4482) * [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fix --- colossalai/shardformer/layer/linear.py | 10 +- colossalai/shardformer/modeling/chatglm2.py | 135 ++++++++++++++++--- colossalai/shardformer/policies/bert.py | 18 ++- colossalai/shardformer/policies/blip2.py | 23 ++-- colossalai/shardformer/policies/bloom.py | 26 ++-- colossalai/shardformer/policies/chatglm2.py | 101 +++++++++----- colossalai/shardformer/policies/gpt2.py | 6 +- colossalai/shardformer/policies/llama.py | 6 +- colossalai/shardformer/policies/sam.py | 12 +- colossalai/shardformer/policies/vit.py | 12 +- tests/kit/model_zoo/transformers/chatglm2.py | 4 +- 11 files changed, 259 insertions(+), 94 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 69ac3ad2581a..81c3f973fd49 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -74,6 +74,7 @@ def __init__(self, process_group: ProcessGroup = None, gather_output: bool = 
False, seq_parallel: bool = False, + seq_parallel_dim: int = 1, overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -87,6 +88,7 @@ def __init__(self, self.out_features = out_features self.gather_output = gather_output self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device @@ -190,7 +192,8 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1, self.overlap) + self.process_group, True, + self.seq_parallel_dim, self.overlap) else: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -236,6 +239,7 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, seq_parallel: bool = False, + seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -254,6 +258,7 @@ def __init__(self, self.skip_bias_add = skip_bias_add self.process_group = process_group self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -390,7 +395,8 @@ def forward(self, input_: Tensor) -> Tensor: else: output_parallel = F.linear(input_, self.weight) if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, + self.seq_parallel_dim) else: output = reduce_forward(output_parallel, self.process_group) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 409e2e1f5497..16dcf87c8cfc 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -9,6 +9,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -146,6 +148,7 @@ def chatglm_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) output_hidden_states = (output_hidden_states @@ -198,6 +201,11 @@ def chatglm_model_forward( all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] + + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -214,6 +222,11 @@ def chatglm_model_forward( hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) + + if shard_config.enable_sequence_parallelism: + hidden_states = 
gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -233,23 +246,22 @@ def chatglm_model_forward( return {'hidden_states': hidden_states} @staticmethod - def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): + def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None): logger = logging.get_logger(__name__) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) @@ -266,6 +278,7 @@ def chatglm_for_conditional_generation_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -296,3 +309,91 @@ def chatglm_for_conditional_generation_forward( ) else: return transformer_outputs + + +def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if 
attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] + inputs_embeds = split_forward_gather_backward(inputs_embeds, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fe091c658682..19dd95fd6b6a 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -155,20 +155,26 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bert_flash_attention_forward(), - }) + }, + policy=policy, + target_key=BertSelfAttention) # use jit operator if self.shard_config.enable_jit_fused: - policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_self_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BertOutput] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BertSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=BertOutput) return policy diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 3610e2c4109b..2e5388ab0490 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -285,21 +285,26 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_blip2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=Blip2Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[Blip2QFormerSelfOutput] = 
ModulePolicyDescription( - method_replacement={ - 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_blip2_QFormer_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=Blip2QFormerOutput) return policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 2727272d0867..21db13f6e441 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -125,25 +125,33 @@ def module_policy(self): target_key=BloomModel) if self.shard_config.enable_flash_attention: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bloom_flash_attention_forward(), - 'dropout_add': get_dropout_add_func() - }) + 'dropout_add': get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention) # enable jit fused operator if self.shard_config.enable_jit_fused: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomMLP] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomAttention) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_mlp_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomGelu] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomMLP) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_gelu_forward(), 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), - }) + }, + policy=policy, + target_key=BloomGelu) return policy diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index a15aa856dcb8..b0d684a67dce 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -15,7 +15,11 @@ GLMBlock, ) -from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm2 import ( + get_chatglm_sequence_parallel_forward_fn, + get_flash_core_attention_forward, + get_jit_fused_glm_block_forward, +) from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -45,8 +49,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[ SubModuleReplacementDescription( @@ -55,36 +59,42 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) ]) - policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.num_attention_heads_per_partition": - 
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.projection_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads) // - self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // - self.shard_config.tensor_parallel_size, - "self_attention.core_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.core_attention.hidden_size_per_partition": - self.model.config.kv_channels * self.model.config.num_attention_heads // - self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.core_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription(suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + # optimization configuration if self.shard_config.enable_fused_normalization: if not self.model.config.rmsnorm: @@ -124,16 +134,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # use flash attention if self.shard_config.enable_flash_attention: - policy[CoreAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_flash_core_attention_forward(), - }) + }, + policy=policy, + target_key=CoreAttention) + + # use sequence parallel + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=ChatGLMModel) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[GLMBlock] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': 
get_jit_fused_glm_block_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=GLMBlock) return policy @@ -178,7 +199,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d34c0ae9fe64..acae2630942b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -118,9 +118,11 @@ def module_policy(self): target_key=GPT2Block) if self.shard_config.enable_flash_attention: - policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_gpt2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=GPT2Attention) if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5ee95f3be8fa..ccf7764079a9 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -105,9 +105,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaModel) if self.shard_config.enable_flash_attention: - policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_llama_flash_attention_forward(), - }) + }, + policy=policy, + target_key=LlamaAttention) return policy diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index b1eba0432b49..9753d5a737b9 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -199,12 +199,16 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[SamAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_sam_flash_attention_forward(), - }) - policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=SamAttention) + self.append_or_create_method_replacement(description={ 'forward': get_sam_vision_flash_attention_forward(), - }) + }, + policy=policy, + target_key=SamVisionAttention) return policy diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 617720ee7950..757bab95f273 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -90,16 +90,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # use flash attention if self.shard_config.enable_flash_attention: - policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': 
get_vit_flash_self_attention_forward(), - }) + }, + policy=policy, + target_key=ViTSelfAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[ViTOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_vit_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=ViTOutput) return policy def new_model_class(self): diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index c6473ee2a025..d543df00bdfa 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -12,8 +12,8 @@ def data_gen(): - input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) return dict(input_ids=input_ids, attention_mask=attention_mask) From e04436a82aa847db166cb181053c290c8a150496 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 23 Aug 2023 15:05:24 +0800 Subject: [PATCH 13/33] [shardformer] tests for 3d parallel (#4493) --- tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_bert.py | 36 ++++++++++++++++ .../test_model/test_shard_bloom.py | 38 +++++++++++++++++ .../test_model/test_shard_chatglm2.py | 35 ++++++++++++++++ .../test_model/test_shard_gpt2.py | 36 ++++++++++++++++ .../test_model/test_shard_llama.py | 37 ++++++++++++++++- .../test_model/test_shard_opt.py | 35 ++++++++++++++++ .../test_model/test_shard_t5.py | 35 ++++++++++++++++ .../test_model/test_shard_vit.py | 35 ++++++++++++++++ .../test_model/test_shard_whisper.py | 41 +++++++++++++++++-- 10 files changed, 324 insertions(+), 5 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 789b3b24e696..811471bec3c8 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -245,7 +245,6 @@ def check_grad(org_model: Module, org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index c967017041af..76f8c0541de5 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -120,12 +120,40 @@ def run_bert_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_bert_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + 
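+    # With 8 ranks and tp_size=2 / pp_size=2, the remaining data-parallel size is 2, so this config
+    # exercises the full 3D (tp x pp x dp) layout; the resets below clear shared layout-converter and
+    # randomizer state so that subsequent test runs start clean.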
clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + def check_bert(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_bert_test() +def check_bert_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -133,5 +161,13 @@ def test_bert(): spawn(check_bert, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert_3d(): + spawn(check_bert_3d, 8) + + if __name__ == "__main__": test_bert() + test_bert_3d() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index bd87be8b7b65..0e236fd47934 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,6 +3,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -118,6 +119,29 @@ def run_bloom_test(test_config): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_bloom_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -127,6 +151,12 @@ def check_bloom(rank, world_size, port): run_bloom_test() +def check_bloom_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -134,5 +164,13 @@ def test_bloom(): spawn(check_bloom, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom_3d(): + spawn(check_bloom_3d, 8) + + if __name__ == "__main__": test_bloom() + test_bloom_3d() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 64732e06bbc4..a8957d8d3f22 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -145,12 +145,39 @@ def run_chatglm_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_chatglm_3d_test(test_config): + sub_model_zoo = 
model_zoo.get_sub_registry('transformers_chatglm') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_chatglm(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_chatglm_test() +def check_chatglm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -158,5 +185,13 @@ def test_chatglm(): spawn(check_chatglm, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm_3d(): + spawn(check_chatglm_3d, 8) + + if __name__ == "__main__": test_chatglm() + test_chatglm_3d() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c776a80d8b65..85d66e493e03 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -141,12 +141,40 @@ def run_gpt2_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +@clear_cache_before_run() +def run_gpt2_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_gpt2(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_gpt2_test() +def check_gpt2_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -154,5 +182,13 @@ def test_gpt2(): spawn(check_gpt2, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2_3d(): + spawn(check_gpt2_3d, 8) + + if __name__ == "__main__": test_gpt2() + test_gpt2_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 7140c4666861..485d2685e8f4 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -56,7 +56,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # unwrap model llama_model = unwrap_model(org_model, 'LlamaModel', 'model') shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') - # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] col_layer_for_check = ['layers[0].self_attn.o_proj'] @@ -156,12 +155,40 @@ def run_llama_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 
'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_llama_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_llama_test() +def check_llama_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -169,5 +196,13 @@ def test_llama(): spawn(check_llama, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) + + if __name__ == "__main__": test_llama() + test_llama_3d() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index e6faafdaea4a..ad344585e8ce 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -146,12 +146,39 @@ def run_opt_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_opt_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_opt_test() +def check_opt_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_opt_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -159,5 +186,13 @@ def test_OPTModel(): spawn(check_OPTModel, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_opt_3d(): + spawn(check_opt_3d, 8) + + if __name__ == '__main__': test_OPTModel() + test_opt_3d() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 599f5a80d8ba..a853f024deb2 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -137,12 +137,39 @@ def run_t5_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_t5_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + + for name, (model_fn, 
data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_t5(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_t5_test() +def check_t5_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -150,5 +177,13 @@ def test_t5(): spawn(check_t5, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5_3d(): + spawn(check_t5_3d, 8) + + if __name__ == "__main__": test_t5() + test_t5_3d() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index b27add24cd09..0b092966cfd8 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -146,12 +146,39 @@ def run_vit_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_vit_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_vit(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_vit_test() +def check_vit_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -159,5 +186,13 @@ def test_vit(): spawn(check_vit, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit_3d(): + spawn(check_vit_3d, 8) + + if __name__ == "__main__": test_vit() + test_vit_3d() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 90e007e34de8..6445b314dc97 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -82,8 +82,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) # check weights after optimizer.step() org_optimizer.step() @@ -99,7 +99,7 @@ def check_forward_backward(model_fn, data_gen_fn, 
output_transform_fn, loss_fn, tp_group, atol=atol, rtol=rtol, - dim=0, + dim=1, verbose=False) check_weight(whisper, sharded_whisper, @@ -155,12 +155,39 @@ def run_whisper_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_whisper_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_whisper(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_whisper_test() +def check_whisper_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_whisper_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -168,5 +195,13 @@ def test_whisper(): spawn(check_whisper, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper_3d(): + spawn(check_whisper_3d, 8) + + if __name__ == "__main__": test_whisper() + test_whisper_3d() From 3353e55c80d22c765314ca4f4886d61f0a58cdd7 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 24 Aug 2023 15:50:02 +0800 Subject: [PATCH 14/33] [shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fix. 
(#4498) * [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * activate checks --- colossalai/shardformer/modeling/bert.py | 6 +++++ colossalai/shardformer/policies/llama.py | 5 ++++ colossalai/shardformer/policies/opt.py | 12 ++++++--- colossalai/shardformer/policies/t5.py | 5 ++++ colossalai/shardformer/policies/vit.py | 5 ++++ colossalai/shardformer/policies/whisper.py | 27 +++++++++---------- .../test_model/test_shard_whisper.py | 7 ++--- 7 files changed, 46 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index d88661953a29..30855a622adb 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -187,6 +187,9 @@ def bert_model_forward( hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -1241,6 +1244,9 @@ def forward( embedding_output = split_forward_gather_backward(embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) encoder_outputs = self.encoder( embedding_output, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ccf7764079a9..c417e5d017bd 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -35,6 +36,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement={ diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 58663553b922..abe491bfaace 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -104,16 +104,20 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_opt_flash_attention_forward(), - }) + }, + policy=policy, + target_key=OPTAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[OPTDecoderLayer] = 
ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_opt_decoder_layer_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=OPTDecoderLayer) return policy diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 651883d35b87..192a1b8472fc 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Optional, Tuple @@ -59,6 +60,10 @@ def module_policy(self): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 757bab95f273..b4fb8692e684 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, Dict, List, Union import torch.nn as nn @@ -32,6 +33,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, param_replacement=[], diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index a33f929f1e48..bffb624d0d1a 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Tuple @@ -33,7 +34,6 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - # TODO: vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: @@ -52,6 +52,14 @@ def module_policy(self): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_jit_fused: + self.shard_config.enable_jit_fused = False + warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.") + if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ "self_attn.embed_dim": @@ -198,20 +206,11 @@ def module_policy(self): # enable flash attention if self.shard_config.enable_flash_attention: - policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_whisper_flash_attention_forward(), - }) - - # use jit fused operator - if self.shard_config.enable_jit_fused: - policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={ - 'forward': 
get_jit_fused_whisper_encoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_whisper_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=WhisperAttention) return policy diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6445b314dc97..011fb8d238cc 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 @@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights and gradients if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 @@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): @@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # TODO(jianghai) fix fp16 +#TODO fix WhisperForConditionalGeneration enable jit fused operator @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, From de8a65babcf3bdf50fd1a60ff0baabe3e4f7803e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 25 Aug 2023 19:41:24 +0800 Subject: [PATCH 15/33] [shardformer] opt fix. 
(#4514) * [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * activate checks * [Test] test ci * test ci * test ci * test ci * test ci * test ci * test ci * fix --- colossalai/shardformer/policies/opt.py | 26 +++++++++---------- .../test_model/test_shard_opt.py | 1 - .../test_model/test_shard_whisper.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index abe491bfaace..be9d1c58b79e 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -103,21 +103,21 @@ def module_policy(self): target_key=OPTDecoderLayer) # use flash attention - if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_opt_flash_attention_forward(), - }, - policy=policy, - target_key=OPTAttention) + # if self.shard_config.enable_flash_attention: + # self.append_or_create_method_replacement(description={ + # 'forward': get_opt_flash_attention_forward(), + # }, + # policy=policy, + # target_key=OPTAttention) # use jit fused operator - if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_opt_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=OPTDecoderLayer) + # if self.shard_config.enable_jit_fused: + # self.append_or_create_method_replacement(description={ + # 'forward': get_jit_fused_opt_decoder_layer_forward(), + # 'dropout_add': get_jit_fused_dropout_add_func(), + # }, + # policy=policy, + # target_key=OPTDecoderLayer) return policy diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index ad344585e8ce..71483b752c34 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -137,7 +137,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'initial_scale': 1 }]) def run_opt_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 011fb8d238cc..6eaed7d37e47 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if test_config['precision'] == 'fp32': - atol, rtol = 2e-4, 2e-4 + atol, rtol = 5e-4, 5e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): From 44eab2b27f8f854d5fb050a3b5aa83e79effd0b6 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 25 Aug 2023 22:04:57 +0800 
Subject: [PATCH 16/33] [shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506) * add APIs * implement save_sharded_model * add test for hybrid checkpointio * implement naive loading for sharded model * implement efficient sharded model loading * open a new file for hybrid checkpoint_io * small fix * fix circular importing * fix docstring * arrange arguments and apis * small fix --- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- colossalai/checkpoint_io/__init__.py | 3 +- .../hybrid_parallel_checkpoint_io.py | 316 ++++++++++++++++++ colossalai/checkpoint_io/utils.py | 58 +++- .../shardformer/layer/parallel_module.py | 9 +- colossalai/zero/gemini/gemini_ddp.py | 28 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 116 +++++++ 7 files changed, 496 insertions(+), 39 deletions(-) create mode 100644 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py create mode 100644 tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 016323ae7821..c49b3e1823cd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -16,7 +16,7 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO +from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule @@ -292,6 +292,7 @@ def __init__(self, self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, @@ -460,7 +461,7 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return None + return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index c25048e25754..07b1f81dace6 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,5 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py new file mode 100644 index 000000000000..56a89bff75ca --- /dev/null +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -0,0 +1,316 @@ +import copy +import gc +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Any, Callable, Iterator, List, Optional, 
OrderedDict, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from colossalai.cluster import ProcessGroupMesh +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) + +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( + StateDictSharder, + calculate_tensor_size, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + get_shard_filename, + is_safetensors_available, + load_shard_state_dict, + load_state_dict_into_model, + save_param_groups, + save_state_dict, + save_state_dict_shards, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +class HypridParallelCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. + + Args: + dp_group (ProcessGroup): Process group along data parallel dimension. + pp_group (ProcessGroup): Process group along pipeline parallel dimension. + tp_group (ProcessGroup): Process group along tensor parallel dimension. + """ + + def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: + super().__init__() + self.dp_group = dp_group + self.pp_group = pp_group + self.tp_group = tp_group + self.dp_rank = dist.get_rank(self.dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.dp_size = dist.get_world_size(dp_group) + self.pp_size = dist.get_world_size(pp_group) + self.tp_size = dist.get_world_size(tp_group) + + @staticmethod + def _model_sharder(model: nn.Module, + prefix: str = '', + keep_vars: bool = False, + size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + @staticmethod + def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. 
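+        # (A possible shape for this, mirroring _model_sharder above, would be to walk the entries of
+        #  optimizer.state_dict()['state'] and append each one to a StateDictSharder(size_per_shard),
+        #  yielding blocks as they fill up. This is only a sketch of the intended design, not the final
+        #  implementation.)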
+ # TODO (Baizhou): Implement sharding feature of optimizer. + pass + + def save_sharded_model(self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_size == 0 are responsible for model saving. + state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. 
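+            # Illustration with the default base names: pipeline stage 0 would write weight shards named
+            # like "pytorch_model-stage-00000-shard.bin" (a shard counter is appended by get_shard_filename
+            # when each shard is written) and a per-stage index "pytorch_model.bin.index-stage-00000.json";
+            # the exact base names depend on get_model_base_filenames and the optional prefix.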
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + index_file_path (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + missing_keys = [] + + load_state_dict_into_model(model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True) + del state_dict + loaded_file.add(filename) + + # Load parameters. 
+ for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + _load(extra_state_key) + + def save_sharded_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + pass + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + pass + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 8837776aee4d..d04159c54d5e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -13,7 +13,12 @@ from colossalai.interface import OptimizerWrapper from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor.d_tensor import is_distributed_tensor +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): + """ + Gather the complete parameter for saving if passed in param is distributed. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. 
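+
+    Example (illustrative):
+        # a plain tensor comes back (detached) unchanged, while a distributed tensor is gathered
+        # into its full global form before saving
+        full_param = gather_distributed_param(module.weight)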
+ + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + + # ====================================== -# Helper functions for saving shard file +# Helper classes and functions for saving shard file # ====================================== def unwrap_optimizer(optimizer: OptimizerWrapper): ''' @@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): return unwrapped_optim +class StateDictSharder: + + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size + + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", @@ -126,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] total_size = 0 for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair if not is_master: + del shard continue - shard, current_size = shard_pair shard_file = get_shard_filename(base_filename, idx) total_size = total_size + current_size for key in shard.keys(): @@ -137,6 +188,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] # Only save on master rank. 
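+        # Write this shard on the master rank, then drop it with `del` below so its memory is released
+        # while the remaining shards are generated (non-master ranks already deleted their copy above).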
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + del shard return total_size diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index bda147b121ab..4f391920e29b 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, @@ -56,13 +57,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - destination[prefix + name] = to_global(param_) - elif is_customized_distributed_tensor(param_): - destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) - else: - destination[prefix + name] = param_ + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 08384ee82d0b..5aff91f03153 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -8,7 +8,7 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage @@ -657,7 +657,7 @@ def state_dict_shard(self, Yields: Iterator[OrderedDict]: A generator of state dict shard """ - sharder = _StateDictSharder(max_shard_size) + sharder = StateDictSharder(max_shard_size) # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() @@ -705,30 +705,6 @@ def state_dict_shard(self, yield sharder.current_block, sharder.current_block_size -class _StateDictSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - ret_block_size = 0 - - # before we return the current block and create a new block, - # we need to ensure that the current block is not empty - if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: - ret_block = self.current_block - ret_block_size = self.current_block_size - self.current_block = OrderedDict() - self.current_block_size = 0 - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block, ret_block_size - - class GeminiDDP(ZeroDDP): def __init__(self, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py new file mode 100644 index 000000000000..ea0922ef5dec --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -0,0 +1,116 @@ +import pytest +import torch +import torch.distributed as 
dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize('shard', [True]) +@parameterize('model_name', ['transformers_gpt']) +@parameterize('size_per_shard', [32]) +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp32', +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}]) +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): + + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + data_iter = iter([data]) + output = booster.execute_pipeline(data_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + else: + data = {k: v.cuda() for k, v in data.items()} + output = model(**data) + loss = criterion(output) + optimizer.backward(loss) + + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + # optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + + clear_layout_converter() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) From 376533a56411d3826df2a5b3aabc5471016496bf Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 28 Aug 2023 10:51:16 +0800 Subject: [PATCH 17/33] [shardformer] zero1+pp and the corresponding tests (#4517) * pause * finish pp+zero1 * Update test_shard_vit.py 
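
A rough sketch of how the two pieces below are meant to interact (illustrative only; the helper name and signature are assumptions, not code from this patch): the last pipeline stage still calls the ZeRO-1 optimizer's backward() on the micro-batch loss, while a non-last stage continues the backward pass from the gradient handed back by the next stage via the new backward_by_grad().

from typing import Optional

import torch


def run_backward_for_stage(optimizer, stage_output: torch.Tensor, upstream_grad: Optional[torch.Tensor] = None):
    # Last stage: stage_output is the micro-batch loss, so the usual
    # (possibly loss-scaled) backward through the optimizer is enough.
    if upstream_grad is None:
        optimizer.backward(stage_output)
    else:
        # Non-last stage: continue autograd from the gradient sent back by the
        # next stage; backward_by_grad applies the optimizer's mixed-precision
        # pre-backward hook to the gradient before calling torch.autograd.backward.
        optimizer.backward_by_grad(stage_output, upstream_grad)
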
--- colossalai/pipeline/schedule/one_f_one_b.py | 3 +- colossalai/zero/low_level/low_level_optim.py | 9 ++- .../test_model/test_shard_bert.py | 9 +++ .../test_model/test_shard_bloom.py | 9 +++ .../test_model/test_shard_gpt2.py | 9 +++ .../test_model/test_shard_llama.py | 9 +++ .../test_model/test_shard_opt.py | 9 +++ .../test_model/test_shard_t5.py | 9 +++ .../test_model/test_shard_vit.py | 11 ++- .../test_model/test_shard_whisper.py | 67 ++++++++++--------- 10 files changed, 109 insertions(+), 35 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index f5e4929aa7c8..0058873c21ba 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -128,11 +128,11 @@ def forward_step(self, Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ micro_batch = self.load_micro_batch() - # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: accum_loss.add_(loss.detach()) @@ -158,7 +158,6 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], # Retain the grad on the input_obj. tree_map(retain_grad, input_obj) - # Backward pass. if output_obj_grad is None: optimizer.backward(output_obj) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 64d6a5395120..a1e85e5b90f6 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -316,7 +316,6 @@ def _add_to_bucket(self, param, group_id): def backward(self, loss, retain_graph=False): assert not(self._partition_grads and not self.require_grad_sync), \ "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" - if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) @@ -333,6 +332,13 @@ def backward(self, loss, retain_graph=False): self.zero_grad() + def backward_by_grad(self, tensor, grad): + # in lower stages, the grad is transferred from a higher stage, + # so we need to pass the optim state down. + if self.mixed_precision_mixin is not None: + grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) + torch.autograd.backward(tensor, grad) + def zero_grad(self, set_to_none=True): """ Set parameter gradients to zero.
If set_to_none = True, gradient @@ -358,7 +364,6 @@ def zero_grad(self, set_to_none=True): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 76f8c0541de5..a15645a7f344 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 0e236fd47934..590eff642e2b 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bloom_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 85d66e493e03..13458fc5420e 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_gpt2_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 485d2685e8f4..8dc6376bfb90 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_llama_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 71483b752c34..939b2d55566e 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_opt_test(test_config): 
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index a853f024deb2..cd3d3d673132 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_t5_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 0b092966cfd8..d40058bb73f7 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if org_model.__class__.__name__ == 'ViTModel': check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model @@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() +#TODO: num_microbatch size = 2 inf loss @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, @@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_vit_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6eaed7d37e47..356ed6405f37 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -112,37 +112,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() - +#TODO fix WhisperForConditionalGeneration enable jit fused operato # TODO(jianghai) fix fp16 -#TODO fix WhisperForConditionalGeneration enable jit fused operator -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp32', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', -}]) +@parameterize( + 'test_config', + [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + { + 
'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + # whisper is not supported fp16 for now. + ]) def run_whisper_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): From c554b7f559b592c4d358db677c87658b11a6341c Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:16:40 +0800 Subject: [PATCH 18/33] =?UTF-8?q?[shardformer/fix=20overlap=20bug]=20fix?= =?UTF-8?q?=20overlap=20bug,=20add=20overlap=20as=20an=20option=20in=20sha?= =?UTF-8?q?rdco=E2=80=A6=20(#4516)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix overlap bug and support bert, add overlap as an option in shardconfig * support overlap for chatglm and bloom --- colossalai/shardformer/layer/_operation.py | 53 ++++++++----------- colossalai/shardformer/layer/linear.py | 2 +- colossalai/shardformer/policies/bert.py | 21 ++++++-- colossalai/shardformer/policies/bloom.py | 11 +++- colossalai/shardformer/policies/chatglm2.py | 4 +- colossalai/shardformer/shard/shard_config.py | 9 ++++ .../test_layer/test_linear_1d.py | 2 +- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f1f48273ccd1..55d9413b9979 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -211,43 +211,36 @@ def backward(ctx, grad_output): handle.wait() else: - # create new stream for calculate the gradient - calculate_stream = torch.cuda.Stream() - - # do all gather in default stream input_ = input_.contiguous() world_size = dist.get_world_size(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) - - # calculate gradient in calculate_stream - with torch.cuda.stream(calculate_stream): - # calculate - grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - # prepare data - input_list = [ - item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() - torch.cuda.current_stream().wait_stream(calculate_stream) + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished gather_handle.wait() + # do reduce-scatter in async way 
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - with torch.cuda.stream(calculate_stream): - input_parallel = torch.cat(tensor_list, dim=dim).contiguous() - if len(input_parallel.shape) > 2: - input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) - print(grad_output.shape, input_parallel.shape) - grad_weight = grad_output.t().matmul(input_parallel) - - torch.cuda.current_stream().wait_stream(calculate_stream) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = grad_output.t().matmul(input_parallel) + # wait until reduce-scatter finished reducescatter_handle.wait() return output, grad_weight, grad_bias, None, None, None, None diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 81c3f973fd49..111d51b3f8d8 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -75,7 +75,7 @@ def __init__(self, gather_output: bool = False, seq_parallel: bool = False, seq_parallel_dim: int = 1, - overlap: bool = False, + overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 19dd95fd6b6a..a141b7bd8fdf 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -56,6 +56,7 @@ def module_policy(self): policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ "attention.self.all_head_size": @@ -71,17 +72,26 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -99,7 +109,10 @@ def module_policy(self): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="output.dense", diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 21db13f6e441..7c418d02bcb6 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -45,6 +45,7 @@ def module_policy(self): policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = 
ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -55,7 +56,10 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={'seq_parallel': use_sequence_parallel}), + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'overlap': overlap + }), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, @@ -67,7 +71,10 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - kwargs={'seq_parallel': use_sequence_parallel}), + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'overlap': overlap + }), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index b0d684a67dce..5bcbc2acc28e 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -50,6 +50,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[ @@ -81,7 +82,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ 'seq_parallel': use_sequence_parallel, - 'seq_parallel_dim': 0 + 'seq_parallel_dim': 0, + 'overlap': overlap }), SubModuleReplacementDescription(suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 900f8475c71b..c5c3d185e950 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -20,6 +20,8 @@ class ShardConfig: enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False. 
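+
+        Example (illustrative, not part of the original docstring; sequence parallelism
+        requires tensor parallelism, and sequence overlap requires sequence parallelism,
+        as enforced in __post_init__ below):
+            shard_config = ShardConfig(enable_tensor_parallelism=True,
+                                       enable_sequence_parallelism=True,
+                                       enable_sequence_overlap=True)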
""" tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -29,6 +31,7 @@ class ShardConfig: enable_flash_attention: bool = False enable_jit_fused: bool = False enable_sequence_parallelism: bool = False + enable_sequence_overlap: bool = False # pipeline_parallel_size: int # data_parallel_size: int @@ -41,6 +44,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): + if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: + raise ValueError( + "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True") + if not self.enable_sequence_parallelism and self.enable_sequence_overlap: + raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") if not self.enable_tensor_parallelism: self._tensor_parallel_size = 1 else: @@ -59,3 +67,4 @@ def _turn_on_all_optimization(self): self.enable_flash_attention = True self.enable_jit_fused = True self.enable_sequence_parallelism = True + self.enable_sequence_overlap = True diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 3ad8f14b99e6..e6d86d533ed6 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -168,7 +168,7 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool @parameterize('lazy_init', [False, True]) @parameterize('seq_parallel', [False, True]) -@parameterize('overlap', [False, True]) +@parameterize('overlap', [True]) def run_dist_linear_test(lazy_init, seq_parallel, overlap): check_linear_1d_col(lazy_init, seq_parallel, overlap) check_linear_1d_row(lazy_init, seq_parallel) From 0387a47e63520bf112f80d094b64e1ae5890d525 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 11:25:05 +0800 Subject: [PATCH 19/33] [shardformer] fix emerged bugs after updating transformers (#4526) --- colossalai/pipeline/schedule/_utils.py | 5 ++++- tests/test_shardformer/test_model/_utils.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 3ed9239272f1..5cd934b76822 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any: merged_data = [] for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): - merged_data.append(torch.cat(elem_batch, dim=0)) + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + merged_data.append(None) + else: + merged_data.append(torch.cat(elem_batch, dim=0)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 811471bec3c8..803afc48ac09 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor, sharded_hidden_state = sharded_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim) + pipeline_output = sharded_output['outputs'] + if isinstance(pipeline_output, List): + sharded_hidden_state = 
torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim) + else: + sharded_hidden_state = pipeline_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" From e241b74f24ac4efe4712bcefedfd7f14f3dd7b37 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 29 Aug 2023 18:30:50 +0800 Subject: [PATCH 20/33] [shardformer] Add overlap support for gpt2 (#4535) * add overlap support for gpt2 * remove unused code * remove unused code --- colossalai/shardformer/layer/_operation.py | 87 ++++++++++++----- .../shardformer/layer/qkv_fused_linear.py | 4 +- .../shardformer/policies/base_policy.py | 19 ---- colossalai/shardformer/policies/gpt2.py | 94 ++++++++++--------- .../test_gpt2_qkv_fused_linear_1d.py | 10 +- 5 files changed, 120 insertions(+), 94 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 55d9413b9979..f45ccc64bae5 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -291,12 +291,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim + ctx.overlap = overlap input_parallel = _gather(input_, dim, process_group) @@ -312,37 +313,70 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group + overlap = ctx.overlap - # TODO: overlap SP input with gradient computation - input_parallel = _gather(input_, dim, process_group) + if not overlap: + input_parallel = _gather(input_, dim, process_group) - total_input = input_parallel - grad_input = grad_output.matmul(weight.T) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + 
handle.wait() - # TODO: overlap SP input with gradient computation - if ctx.async_grad_reduce_scatter: - # Asynchronous reduce-scatter + else: + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() - handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # reduce-scatter scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 - - grad_weight = total_input.t().matmul(grad_output) - grad_bias = grad_output.sum(dim=0) if use_bias else None + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() - if ctx.async_grad_reduce_scatter: - handle.wait() + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = input_parallel.t().matmul(grad_output) + # wait until reduce-scatter finished + reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None + return output, grad_weight, grad_bias, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -510,9 +544,10 @@ def linear_reducescatter_forward_gather_backward(input_, process_group, dim): return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim) + async_grad_reduce_scatter, dim, overlap) def gather_forward_split_backward(input_, dim, process_group): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index ccb2bf7ea4cc..5ce77805f9b8 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -177,6 +177,7 @@ def __init__(self, async_communication: bool = False, gather_output: bool = False, seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -190,6 +191,7 @@ def __init__(self, self.out_features = out_features self.gather_output = gather_output self.seq_parallel = 
seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -308,7 +310,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel: input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1) + self.process_group, True, 1, self.overlap) else: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 7022a1cfd7a2..961c6a5259fe 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -226,22 +226,3 @@ def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] - - def append_seq_parallel_to_policy( - self, - suffix_list: List[str], - module_policy_description: ModulePolicyDescription, - ): - r""" - Append the sequence parallel policy to the policy for the given key. - - Args: - suffix_list (List[str]): the suffix list of the module to be parallelized - policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated - """ - - for sub_description in module_policy_description.sub_module_replacement: - if (sub_description.suffix in suffix_list): - if sub_description.kwargs is None: - sub_description.kwargs = {} - sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index acae2630942b..5093fd469af8 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -37,7 +37,8 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -50,47 +51,54 @@ def module_policy(self): ), ]) - policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GPT2Block] 
= ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -126,8 +134,6 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} - suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] - self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) return policy diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index ae6a1dc90dc5..4c0f884a7ed5 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): process_group=None, gather_output=True, seq_parallel=seq_parallel, - n_fused=3) + n_fused=3, + overlap=overlap) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): @parameterize('lazy_init', [False, True]) @parameterize('seq_parallel', [False, True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel) +@parameterize('overlap', [True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) check_linear_conv_1d_row(lazy_init, seq_parallel) From d367b8878589449cd5410ac8c4da756de6313aad Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> 
Date: Wed, 30 Aug 2023 14:50:34 +0800 Subject: [PATCH 21/33] [shardformer] fix opt test hanging (#4521) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix --- colossalai/shardformer/policies/opt.py | 26 +++---- colossalai/shardformer/policies/t5.py | 25 ++++-- colossalai/shardformer/policies/whisper.py | 18 ++++- tests/test_shardformer/test_model/_utils.py | 52 +++++++++++++ .../test_model/test_shard_bert.py | 56 ++++++++++---- .../test_model/test_shard_bloom.py | 57 +++++++++----- .../test_model/test_shard_chatglm2.py | 76 +++++++++++-------- .../test_model/test_shard_gpt2.py | 59 +++++++++----- .../test_model/test_shard_llama.py | 75 ++++++++++-------- .../test_model/test_shard_opt.py | 74 ++++++++++-------- .../test_model/test_shard_t5.py | 50 +++++++----- .../test_model/test_shard_vit.py | 71 +++++++++-------- .../test_model/test_shard_whisper.py | 58 +++++++++----- 13 files changed, 460 insertions(+), 237 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index be9d1c58b79e..abe491bfaace 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -103,21 +103,21 @@ def module_policy(self): target_key=OPTDecoderLayer) # use flash attention - # if self.shard_config.enable_flash_attention: - # self.append_or_create_method_replacement(description={ - # 'forward': get_opt_flash_attention_forward(), - # }, - # policy=policy, - # target_key=OPTAttention) + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement(description={ + 'forward': get_opt_flash_attention_forward(), + }, + policy=policy, + target_key=OPTAttention) # use jit fused operator - # if self.shard_config.enable_jit_fused: - # self.append_or_create_method_replacement(description={ - # 'forward': get_jit_fused_opt_decoder_layer_forward(), - # 'dropout_add': get_jit_fused_dropout_add_func(), - # }, - # policy=policy, - # target_key=OPTDecoderLayer) + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_opt_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=OPTDecoderLayer) return policy diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 192a1b8472fc..92cbd3f72b83 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -184,24 +184,33 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[T5Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_t5_flash_attention_forward(), - }) + }, + policy=policy, + target_key=T5Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[T5LayerFF] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_T5_layer_ff_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=T5LayerFF) + self.append_or_create_method_replacement(description={ 'forward': get_T5_layer_self_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + 
target_key=T5LayerSelfAttention) + self.append_or_create_method_replacement(description={ 'forward': get_T5_layer_cross_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=T5LayerCrossAttention) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index bffb624d0d1a..5d496f08e1db 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -56,9 +56,6 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn( "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") - if self.shard_config.enable_jit_fused: - self.shard_config.enable_jit_fused = False - warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.") if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ @@ -212,6 +209,21 @@ def module_policy(self): policy=policy, target_key=WhisperAttention) + # use jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_whisper_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperDecoderLayer) + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_whisper_encoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperEncoderLayer) + return policy def add_lm_head_policy(self, base_policy): diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 803afc48ac09..72bb2b025ba4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -237,6 +237,43 @@ def check_weight(org_model: Module, f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" +def get_grad_tensors_for_check(org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, + name: str = None): + + grad_to_check = {} + for suffix in layer_suffix: + org_grad = getattr_(org_model, suffix).weight.grad + shard_grad = getattr_(sharded_model, suffix).weight.grad + shard_weight = getattr_(sharded_model, suffix).weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(shard_grad_list, shard_grad, tp_group) + shard_grad = torch.cat(shard_grad_list, dim=dim) + + # embedding may be resized when using tensor parallel + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[:org_grad.shape[0], :] + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' grad: {org_grad}, {shard_grad}") + + grad_to_check[suffix] = { + "org_grad": org_grad.float(), + "shard_grad": shard_grad.float(), + "rtol": rtol, + "atol": atol + } + + return grad_to_check + + +# used by sam/blip2 def check_grad(org_model: Module, sharded_model: Module, layer_suffix: List[str], @@ -275,3 +312,18 @@ def unwrap_model(module: Module, if module.__class__.__name__ == base_model_class_name: return module return getattr(module, 
base_model_attribute_name, None) + + +def check_all_grad_tensors(check_tensors): + """ + "org_grad": tensor to be compared from the original model + "shard_grad": tensor to be compared from the sharded model + """ + for suffix, check_info in check_tensors.items(): + org_grad = check_info["org_grad"] + shard_grad = check_info["shard_grad"] + rtol = check_info["rtol"] + atol = check_info["atol"] + assert torch.allclose( + org_grad, shard_grad, atol=atol, rtol=rtol + ), f"error attribute '{suffix}', origin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index a15645a7f344..61881a1f90e7 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -10,10 +10,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -33,18 +34,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, output_transform_fn, criterion, booster) + stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BertModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) bert = unwrap_model(org_model, 'BertModel', 'bert') sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') @@ -52,17 +44,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_for_check = ['encoder.layer[0].output.dense'] row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {} if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - - # check weights after optimizer.step() + col_layer_grads = get_grad_tensors_for_check(bert, + sharded_bert, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + row_layer_grads = get_grad_tensors_for_check(bert, + sharded_bert, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == 'BertModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 else: @@ -70,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager is None or stage_manager.is_first_stage(): check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 590eff642e2b..f7ab94bc9aae 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -9,10 +9,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,35 +37,54 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BloomModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model bloom = unwrap_model(org_model, 'BloomModel', 'transformer') sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') - # check grad row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] col_layer_for_check = ['h[0].self_attention.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model. 
+ grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-5 else: atol, rtol = 5e-3, 5e-3 - check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(bloom, + sharded_bloom, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(bloom, + sharded_bloom, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == 'BloomModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -72,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index a8957d8d3f22..c5a3e68f7b55 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -9,10 +9,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,51 +37,57 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'ChatGLMModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') - # check grad row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model. 
+ grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_grad(chatglm_model, - shard_chatglm_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - - check_grad(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + + col_layer_grads = get_grad_tensors_for_check(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'ChatGLMModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -95,6 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 13458fc5420e..44914721c40e 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -9,10 +9,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,18 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'GPT2Model': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') @@ -55,18 +44,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_for_check = ['h[0].mlp.c_fc'] row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] - # check grad + # Save gradient tensors for comparison between the original model and the sharded model. 
+ grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - - # check weights after optimizer.step() + col_layer_grads = get_grad_tensors_for_check(gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + row_layer_grads = get_grad_tensors_for_check(gpt2, + sharded_gpt2, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'GPT2Model': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 @@ -74,6 +94,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8dc6376bfb90..c9d5d3d08305 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -12,10 +12,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -41,49 +42,56 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'LlamaModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model llama_model = unwrap_model(org_model, 'LlamaModel', 'model') shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') - # check grad + row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] col_layer_for_check = ['layers[0].self_attn.o_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
+ grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-4 else: atol, rtol = 5e-3, 5e-3 - check_grad(llama_model, - shard_llama_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(llama_model, + shard_llama_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'LlamaModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -98,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 939b2d55566e..8c0432b37425 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -11,10 +11,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -40,49 +41,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'OPTModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model opt_model = unwrap_model(org_model, 'OPTModel', 'model') shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') - # check grad row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model. 
+ grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-3 else: - atol, rtol = 3e-2, 3e-2 - check_grad(opt_model, - shard_opt_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() + atol, rtol = 4e-2, 4e-2 + row_layer_grads = get_grad_tensors_for_check(opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == 'OPTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-3, 1e-3 @@ -97,6 +104,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index cd3d3d673132..29367031e820 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -10,10 +10,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -37,42 +38,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ != 'T5ForConditionalGeneration': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model t5 = unwrap_model(org_model) sharded_t5 = unwrap_model(sharded_model) row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] - # check grad + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
+ grads_to_check = {} if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(t5, + sharded_t5, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ != 'T5ForConditionalGeneration': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 + atol, rtol = 5e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index d40058bb73f7..2980c6eeafba 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -9,10 +9,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,17 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'ViTModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model vit_model = unwrap_model(org_model, 'ViTModel', 'vit') shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') @@ -54,31 +44,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] col_layer_for_check = ['encoder.layer[0].attention.output.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
+ grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_grad(vit_model, - shard_vit_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(vit_model, + shard_vit_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'ViTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 @@ -93,6 +101,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 356ed6405f37..a55753018300 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -15,10 +15,11 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, ) @@ -41,18 +42,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 2e-4, 2e-4 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'WhisperModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model @@ -75,19 +64,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, #'decoder.layers[0].self_attn.out_proj' ] - # check weights and gradients + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
+ grads_to_check = {} if test_config['precision'] == 'fp32': atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1) + col_layer_grads = get_grad_tensors_for_check(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 2e-4, 2e-4 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'WhisperModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if test_config['precision'] == 'fp32': atol, rtol = 5e-4, 5e-4 else: @@ -110,8 +128,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=0, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() + #TODO fix WhisperForConditionalGeneration enable jit fused operato # TODO(jianghai) fix fp16 @parameterize( From ec18fc7340f99693f2436e91e1dea99342f476d5 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 30 Aug 2023 21:29:18 +0800 Subject: [PATCH 22/33] [shardformer] support pp+tp+zero1 tests (#4531) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 --- colossalai/zero/low_level/low_level_optim.py | 15 +++++++++++++-- .../test_model/test_shard_bert.py | 9 +++++++++ .../test_model/test_shard_bloom.py | 10 ++++++++++ .../test_model/test_shard_chatglm2.py | 10 ++++++++++ .../test_model/test_shard_gpt2.py | 10 ++++++++++ .../test_model/test_shard_llama.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_opt.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_t5.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_vit.py | 9 +++++++++ .../test_model/test_shard_whisper.py | 11 ++++++++++- 10 files changed, 101 insertions(+), 3 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a1e85e5b90f6..85ac9eb48598 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -333,12 +333,23 @@ def backward(self, loss, retain_graph=False): self.zero_grad() def backward_by_grad(self, tensor, grad): - # in lower stage which grad is transfered by higher stage - # we need to pass the optim state down. 
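# Assumed rationale for the assert introduced below: with ZeRO-2 every
# backward pass reduce-scatters and then partitions its gradients, so skipping
# that step while accumulating (no_sync) would leave each rank holding
# un-reduced full gradients that the partitioned bookkeeping cannot absorb
# later; ZeRO-1 keeps full local gradients and can accumulate safely. When
# self.require_grad_sync is False, the reduction, the optional stream
# synchronization and zero_grad() below are all deferred to the last
# micro-batch of the accumulation window.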
+ assert not(self._partition_grads and not self.require_grad_sync), \ + "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) torch.autograd.backward(tensor, grad) + if not self.require_grad_sync: + return + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + + self.zero_grad() + def zero_grad(self, set_to_none=True): """ Set parameter gradients to zero. If set_to_none = True, gradient diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 61881a1f90e7..0855e2248710 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -163,6 +163,15 @@ def run_bert_test(test_config): 'enable_all_optimization': False, 'use_lazy_init': False, 'precision': 'fp32', + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, 'initial_scale': 1, }, ]) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index f7ab94bc9aae..c9ee690c86dc 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -165,6 +165,16 @@ def run_bloom_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_bloom_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index c5a3e68f7b55..05ca05dea4d6 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -165,6 +165,16 @@ def run_chatglm_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_chatglm_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 44914721c40e..563084ed0f09 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -183,6 +183,16 @@ def run_gpt2_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) @clear_cache_before_run() def run_gpt2_3d_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index c9d5d3d08305..a60150e3cd72 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -185,6 +185,16 @@ def run_llama_test(test_config): 
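# The 'tp_size': 2 / 'pp_size': 2 / 'zero_stage': 1 entry added below (and in
# the other test files of this patch) corresponds, roughly, to booting a model
# like this (a sketch only, with illustrative values):
#
#     plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4,
#                                   precision='fp16', zero_stage=1, initial_scale=1)
#     booster = Booster(plugin=plugin)
#     model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer,
#                                                                criterion, dataloader)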
'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_llama_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 8c0432b37425..25b1eefc6016 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -174,6 +174,16 @@ def run_opt_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_opt_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 29367031e820..768cae0a6734 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -170,6 +170,16 @@ def run_t5_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_t5_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 2980c6eeafba..15db63bfd9da 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -176,6 +176,15 @@ def run_vit_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, ]) def run_vit_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index a55753018300..d0c04c98f80a 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if test_config['precision'] == 'fp32': - atol, rtol = 5e-4, 5e-4 + atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): @@ -195,6 +195,15 @@ def run_whisper_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, ]) def run_whisper_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') From 2c787d7f47f7aa55c27877a66f79e4226d16b92a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 31 Aug 2023 09:57:18 +0800 Subject: [PATCH 23/33] [shardformer] fix submodule replacement bug when enabling pp (#4544) --- colossalai/shardformer/shard/sharder.py | 25 
++++++++++--------- ...st_hybrid_parallel_plugin_checkpoint_io.py | 2 ++ .../test_model/test_shard_chatglm2.py | 2 ++ .../test_model/test_shard_gpt2.py | 2 ++ .../test_model/test_shard_opt.py | 2 ++ 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 0ed745a1fc4a..9ed384266a80 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -92,22 +92,21 @@ def _recursive_replace_layer( param_replacement (List[Callable]): The function list to get parameter shard information in policy method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ - # released layers are not shardable - can_replace_param_or_layer = include is None or module in include if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ (module.__class__ == origin_cls): if attr_replacement is not None: self._replace_attr(module, attr_replacement) - if param_replacement is not None and can_replace_param_or_layer: + if param_replacement is not None and (include is None or module in include): self._replace_param(module, param_replacement) if method_replacement is not None: self._replace_method(module, method_replacement) - if sub_module_replacement is not None and can_replace_param_or_layer: - self._replace_sub_module(module, sub_module_replacement) + if sub_module_replacement is not None: + self._replace_sub_module(module, sub_module_replacement, include) for name, child in module.named_children(): self._recursive_replace_layer(child, @@ -154,18 +153,17 @@ def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Calla bound_method = MethodType(new_method, module) setattr(module, method_name, bound_method) - def _replace_sub_module( - self, - org_layer: nn.Module, - sub_module_replacement: List[SubModuleReplacementDescription], - ) -> None: + def _replace_sub_module(self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: org_layer (torch.nn.Module): The origin layer object to shard sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list - + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ for description in sub_module_replacement: suffix = description.suffix @@ -174,9 +172,12 @@ def _replace_sub_module( assert target_module is not None, 'target_module should not be None' - # TODO: support different parallel mode native_sub_module = getattr_(org_layer, suffix, ignore=True) + # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled. 
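# (With pipeline parallelism enabled, `include` is the set of module objects
#  this stage actually keeps; a submodule held by a different stage is skipped
#  here so that it is not needlessly replaced or sharded on a device that
#  never runs it.)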
+ if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): + continue + assert not isinstance(native_sub_module, target_module), \ f"The module with suffix {suffix} has been replaced, please check the policy" diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index ea0922ef5dec..67d73c31f6e0 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -7,6 +7,7 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( check_state_dict_equal, @@ -100,6 +101,7 @@ def _criterion(outputs, inputs): booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + Randomizer.reset_index() clear_layout_converter() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 05ca05dea4d6..48f651c727f4 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -4,6 +4,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -105,6 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 563084ed0f09..115a1bd79d41 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -4,6 +4,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -97,6 +98,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25b1eefc6016..3e74859ad1a8 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -107,6 +108,7 @@ def check_forward_backward(model_fn, 
data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) + Randomizer.reset_index() torch.cuda.empty_cache() From c9625dbb6364c10f21828b30bc58e8fbcf22a900 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 31 Aug 2023 14:50:47 +0800 Subject: [PATCH 24/33] [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540) * implement sharded optimizer saving * add more param info * finish implementation of sharded optimizer saving * fix bugs in optimizer sharded saving * add pp+zero test * param group loading * greedy loading of optimizer * fix bug when loading * implement optimizer sharded saving * add optimizer test & arrange checkpointIO utils * fix gemini sharding state_dict * add verbose option * add loading of master params * fix typehint * fix master/working mapping in fp16 amp --- .../booster/plugin/hybrid_parallel_plugin.py | 53 ++- .../hybrid_parallel_checkpoint_io.py | 445 +++++++++++++++-- colossalai/checkpoint_io/utils.py | 449 +++++++++--------- colossalai/zero/gemini/gemini_ddp.py | 6 +- colossalai/zero/gemini/gemini_optimizer.py | 44 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 110 +++-- 6 files changed, 775 insertions(+), 332 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c49b3e1823cd..277843b66568 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,7 +1,7 @@ import random from contextlib import nullcontext from functools import partial -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch @@ -110,6 +110,36 @@ def unwrap(self): return module +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. 
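# These integer ids (rather than tensor values or parameter names) are what
# the sharded optimizer checkpoint introduced in this patch is keyed on, so
# they remain valid only as long as the model and optimizer are built the same
# way and enumerate their param_groups in the same order.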
+ + if optim is None: + return {} + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + start_index = 0 + for group in optim.param_groups: + + packed_group = {k: v for k, v in group.items() if k != 'params'} + packed_group['params'] = [] + + for param_id, param in enumerate(group['params'], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group['params'].append(param_id) + param_info['param2id'][id(param)] = param_id + param_info['id2param'][param_id] = id(param) + param_info['param2shape'][id(param)] = original_shape + + param_info['param_groups'].append(packed_group) + start_index += len(group['params']) + + return param_info + + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) new_param_groups = [] @@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim) @@ -133,6 +164,7 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, precision: str = 'fp16', initial_scale: float = 2**16, min_scale: float = 1, @@ -142,6 +174,7 @@ def __init__(self, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -155,6 +188,7 @@ def __init__( optimizer: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2., @@ -172,6 +206,7 @@ def __init__( dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -356,6 +391,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, @@ -366,25 +402,33 @@ def configure( optimizer = HybridParallelAMPOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, precision=self.precision, max_norm=self.max_norm, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, - use_pipeline=self.enable_pipeline_parallelism) + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
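# For reference, the record built by get_param_info() and threaded into each
# optimizer wrapper looks roughly like this (illustrative values; the keys of
# param2id/param2shape are id()s of the working parameters):
#
#     {
#         'param_groups': [{'lr': 1e-3, ..., 'params': [0, 1]}],
#         'param2id':     {id(w0): 0, id(w1): 1},
#         'id2param':     {0: id(w0), 1: id(w1)},
#         'param2shape':  {id(w0): torch.Size([1024, 1024]), id(w1): torch.Size([1024])},
#     }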
optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline(self, @@ -461,7 +505,8 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) + self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 56a89bff75ca..c128858b1efe 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -4,7 +4,7 @@ import os from pathlib import Path from shutil import rmtree -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union import torch import torch.distributed as dist @@ -13,29 +13,23 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from colossalai.cluster import ProcessGroupMesh -from colossalai.tensor.d_tensor import ( - is_customized_distributed_tensor, - is_distributed_tensor, - to_global, - to_global_for_customized_distributed_tensor, -) +from colossalai.interface import OptimizerWrapper from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( StateDictSharder, - calculate_tensor_size, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, is_safetensors_available, load_shard_state_dict, load_state_dict_into_model, + load_states_into_optimizer, save_param_groups, - save_state_dict, save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, ) try: @@ -52,9 +46,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): dp_group (ProcessGroup): Process group along data parallel dimension. pp_group (ProcessGroup): Process group along pipeline parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. 
""" - def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: + def __init__(self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True) -> None: super().__init__() self.dp_group = dp_group self.pp_group = pp_group @@ -65,6 +66,10 @@ def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: Pro self.dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) + self.use_zero = (zero_stage > 0) + self.verbose = verbose + self.working_to_master_map = None + self.master_to_working_map = None @staticmethod def _model_sharder(model: nn.Module, @@ -81,7 +86,7 @@ def _model_sharder(model: nn.Module, continue # Gather tensor pieces when using tensor parallel. param_ = gather_distributed_param(param, keep_vars=False) - block, block_size = state_dict_sharder.append(prefix + name, param_) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -89,7 +94,7 @@ def _model_sharder(model: nn.Module, for name, buf in model.named_buffers(): if buf is not None and name not in model._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = state_dict_sharder.append(prefix + name, buffer) + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -98,7 +103,7 @@ def _model_sharder(model: nn.Module, if getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = model.get_extra_state() - block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size @@ -106,10 +111,44 @@ def _model_sharder(model: nn.Module, yield state_dict_sharder.current_block, state_dict_sharder.current_block_size @staticmethod - def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + def _optimizer_sharder(optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, + size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. - # TODO (Baizhou): Implement sharding feature of optimizer. - pass + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + + for param, state in optimizer.optim.state.items(): + + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info['param2id'][id(working_param)] + original_shape = param_info['param2shape'][id(working_param)] + state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. 
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def save_sharded_model(self, model: nn.Module, @@ -148,7 +187,7 @@ def save_sharded_model(self, return # Then collect the sharded parameters & buffers along tp_group. - # Only devices with tp_size == 0 are responsible for model saving. + # Only devices with tp_rank == 0 are responsible for model saving. state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) @@ -165,9 +204,10 @@ def save_sharded_model(self, if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") else: # When pipeline is used, each stage produces its own shard files and index files. @@ -212,9 +252,10 @@ def save_sharded_model(self, final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ @@ -222,7 +263,7 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri Args: model (nn.Module): The model to be loaded. - index_file_path (str): Path to the index file of checkpointing folder. + checkpoint_index_file (str): Path to the index file of checkpointing folder. strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since params on same device might be stored in different files. """ @@ -263,7 +304,6 @@ def _load(name: str): missing_keys=missing_keys, strict=strict, load_sub_module=True) - del state_dict loaded_file.add(filename) # Load parameters. @@ -271,8 +311,11 @@ def _load(name: str): _load(name) # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) for name, buf in model.named_buffers(): - if buf is not None and name not in model._non_persistent_buffers_set: + if buf is not None and name not in non_persistent_buffers: _load(name) # Load extra states. @@ -281,16 +324,236 @@ def _load(name: str): torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: _load(extra_state_key) + # Update master params if mixed-precision training is enabled. 
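# (The freshly loaded values live in the working fp16/bf16 parameters; the
#  block below copies them back into their fp32 master copies so the next
#  optimizer step starts from the checkpointed weights. Under ZeRO each master
#  copy holds only this rank's 1/dp_size slice of the flattened parameter,
#  hence the pad-and-split before copy_().)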
+ with torch.no_grad(): + if self.working_to_master_map is not None: + for param in model.parameters(): + if (param is None) or (id(param) not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[id(param)] + if self.use_zero: + # master_param is sharded under Zero setting + padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size + if padding_size > 0: + padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padded_param = param.data.view(-1) + sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] + master_param.data.copy_(sharded_param.data) + else: + master_param.data.copy_(param.data) + + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + def save_sharded_optimizer(self, - optimizer: Optimizer, + optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024): - pass + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + master_to_working_map=self.master_to_working_map, + size_per_shard=size_per_shard) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. 
+ index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose: + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + + def _get_param_id_from_optimizer_param(param: torch.Tensor, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info['param2id'][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. 
+ # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg['params']: + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + id_map[param_id] = param - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): - pass + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory.') + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({'param_groups': updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg['params']: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. 
+ for param, state in optimizer.optim.state.items(): + device = param.device + if self.master_to_working_map is not None: + working_param = self.master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info['param2shape'][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state(state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose: + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): # TODO(Baizhou): support this feature after implementing complete state_dict collection @@ -314,3 +577,121 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + + def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], + master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): + """ + Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. + This mapping can only be created when mixied precision is used. + The created mappings should be mappings from integer parameter addresses to parameter objects. + + Args: + working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. + """ + self.working_to_master_map = dict() + for k, v in working_to_master_map.items(): + if isinstance(k, torch.Tensor): + self.working_to_master_map[id(k)] = v + elif isinstance(k, int): + self.working_to_master_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + self.master_to_working_map = dict() + for k, v in master_to_working_map.items(): + if isinstance(k, torch.Tensor): + self.master_to_working_map[id(k)] = v + elif isinstance(k, int): + self.master_to_working_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + @staticmethod + def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, + dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, + inplace: bool) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The complete optimizer state of given parameter. 
+ """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # First gather Zero shards. + if use_zero: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().cpu() + + return state_ + + def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, + original_shape: torch.Size, device: torch.device, + inplace: bool) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d04159c54d5e..0025d07dfc8e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,4 +1,5 @@ # coding=utf-8 +import copy import os import re from collections import abc as container_abcs @@ -8,7 +9,9 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch +import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper @@ -93,24 +96,31 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False -def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: """ - Gather the complete parameter for saving if passed in param is distributed. 
+ Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. Args: - param (torch.Tensor): A model parameter, might be d_tensor. - keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. Returns: - torch.Tensor: the complete parameter + Optional[int]: The dimension along which parameter is partitioned. """ - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - return to_global(param_) - elif is_customized_distributed_tensor(param_): - return to_global_for_customized_distributed_tensor(param_) - else: - return param_ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ + f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim # ====================================== @@ -136,7 +146,8 @@ def __init__(self, size_per_shard: int) -> None: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 @@ -153,6 +164,64 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict] self.current_block_size += tensor_size return ret_block, ret_block_size + def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. + if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + ret_block = None + ret_block_size = 0 + + # directly return if state is stored as distributed tensor + if isDTensor: + return ret_block, ret_block_size + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[param_id] = state + self.current_block_size += state_size + return ret_block, ret_block_size + + +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor: + """ + Gather the complete parameter for saving if passed in param is distributed under tp setting. 
+ + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, @@ -198,28 +267,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for key, weight in state_dict.items(): - ret_block = None - ret_block_size = 0 if not is_distributed_tensor(weight): - weight_size = calculate_tensor_size(weight) - - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - current_block[key] = weight - current_block_size += weight_size + block, block_size = state_dict_sharder.append_param(key, weight) - if ret_block != None: - yield ret_block, ret_block_size + if block != None: + yield block, block_size - yield current_block, current_block_size + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: @@ -230,47 +288,147 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. states = state_dict['state'] - - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): + block, block_size = state_dict_sharder.append_optim_state(param_id, state) + if block != None: + yield block, block_size - ret_block = None - ret_block_size = 0 + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue +# ====================================== +# Helper functions for saving state dict +# ====================================== - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - if not isDTensor: +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. 
+ use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) + else: + torch.save(state_dict, checkpoint_file_path) + + +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + +def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: + """ + Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains + only one tensor. + + Args: + tensor (Tensor): tensor to be saved. + index_file (CheckpointIndexFile): path to the checkpoint file. + size_per_shard (int): size per shard in MB. + """ + root_path = index_file.root_path + output_root_path = root_path.joinpath('dtensor') + + # create directory + output_root_path.mkdir(exist_ok=True) + + # save tensor to this directory + # TODO(YuliangLiu): get index of the tensor shard + # e.g. index = + index = 0 + + # save tensor to file + ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) + ckpt_file_path = output_root_path.joinpath(ckpt_file_name) + + # dtensor ckpt file always contains only one tensor + state_dict = {name: tensor} + save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) + + # update the weight map + # * means all shards + ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + index_file.append_weight_map(name, ckpt_file_name_in_weight_map) + + +def get_checkpoint_file_suffix(use_safetensors: bool) -> str: + """ + Get checkpoint file suffix. + + Args: + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: checkpoint file suffix. + """ + if use_safetensors: + return '.safetensors' + else: + return '.bin' + + +def generate_checkpoint_shard_file_name(index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None) -> str: + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + prefix (str): prefix of the shard file name. Default: None. + + Returns: + str: checkpoint shard file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + + if prefix is None: + return f"{index:05d}-of-{total_number:05d}.{suffix}" + else: + return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" + - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 +def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: + """ + Generate dtensor file name. + + Args: + param_name (str): name of the distributed parameter. + index (int): index of the shard. + use_safetensors (bool): whether to use safetensors to save the checkpoint. 
- current_block[param_id] = state - current_block_size += state_size + Returns: + str: dtensor file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + return f'{param_name}.{index}.{suffix}' - if ret_block != None: - yield ret_block, ret_block_size - yield current_block, current_block_size +# ======================================== +# Helper functions for loading state dict +# ======================================== def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): @@ -383,17 +541,21 @@ def update_group(group, new_group): return id_map -def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False): r"""Copies states from `state_dict` into an Optimizer object. Args: optimizer(Optimizer): An initialized Optimizer object to be loaded - state_dict(dict): a mapping from tensor index (an integer) + state_dict(dict): A mapping from tensor index (an integer) to its states to be loaded (a mapping from state name to a tensor). - id_map(dict): a mapping from tensor index (an integer) + id_map(dict): A mapping from tensor index (an integer) to its corresponding parameter (a tensor) whose states will be updated. + strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False. """ + # Ensure that the keys of state_dict are integers. + state_dict = {int(k): v for k, v in state_dict.items()} + def cast(param, value, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" if isinstance(value, torch.Tensor): @@ -420,7 +582,7 @@ def cast(param, value, key=None): if k in id_map: param = id_map[k] new_states[param] = cast(param, v) - else: + elif not strict: new_states[k] = v optimizer.state.update(new_states) @@ -438,165 +600,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): optimizer.defaults.setdefault('differentiable', False) -# ====================================== -# Helper functions for saving state dict -# ====================================== - - -def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: - """ - Save state dict to checkpoint. - - Args: - state_dict (dict): state dict. - checkpoint_file_path (str): path to the checkpoint file. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - if use_safetensors: - assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." - from safetensors.torch import save_file as safe_save_file - safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) - else: - torch.save(state_dict, checkpoint_file_path) - - -def save_param_groups(state_dict: dict, group_file_path: str) -> None: - """ - Save information of param_groups to given file path. - - Args: - state_dict (dict): state dict. - group_file_path (str): path to the group file. - """ - param_groups = state_dict["param_groups"] - torch.save(param_groups, group_file_path) - - -def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: - """ - Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains - only one tensor. - - Args: - tensor (Tensor): tensor to be saved. - index_file (CheckpointIndexFile): path to the checkpoint file. 
- size_per_shard (int): size per shard in MB. - """ - root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') - - # create directory - output_root_path.mkdir(exist_ok=True) - - # save tensor to this directory - # TODO(YuliangLiu): get index of the tensor shard - # e.g. index = - index = 0 - - # save tensor to file - ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) - ckpt_file_path = output_root_path.joinpath(ckpt_file_name) - - # dtensor ckpt file always contains only one tensor - state_dict = {name: tensor} - save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) - - # update the weight map - # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) - index_file.append_weight_map(name, ckpt_file_name_in_weight_map) - - -def get_checkpoint_file_suffix(use_safetensors: bool) -> str: - """ - Get checkpoint file suffix. - - Args: - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: checkpoint file suffix. - """ - if use_safetensors: - return '.safetensors' - else: - return '.bin' - - -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: - """ - Generate checkpoint shard file name. - - Args: - index (int): index of the shard. - total_number (int): total number of shards. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - prefix (str): prefix of the shard file name. Default: None. - - Returns: - str: checkpoint shard file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - - if prefix is None: - return f"{index:05d}-of-{total_number:05d}.{suffix}" - else: - return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" - - -def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: - """ - Generate dtensor file name. - - Args: - param_name (str): name of the distributed parameter. - index (int): index of the shard. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: dtensor file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' - - -def save_state_dict_as_shard( - state_dict: dict, - checkpoint_path: str, - index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None, -) -> None: - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (str): path to the checkpoint file. - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - # generate the shard name - shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) - shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() - - # save the shard - save_state_dict(state_dict, str(shard_file_path), use_safetensors) - - -# ======================================== -# Helper functions for loading state dict -# ======================================== - - def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: """ Check whether the checkpoint has an index file. 
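A side note on the utils.py changes above: the new search_tp_partition_dim() drives both gather_from_sharded_optimizer_state() and shard_from_complete_optimizer_state(), so a minimal standalone sketch of its shape-based inference may help when reading those two methods. The _sketch helper and the example shapes below are illustrative only and are not part of the patch.

# Standalone sketch of the partition-dim search added above; assumes even TP sharding.
from typing import Optional

import torch


def _search_tp_partition_dim_sketch(current_shape: torch.Size, original_shape: torch.Size,
                                    tp_size: int) -> Optional[int]:
    partition_dim = None
    for dim, length in enumerate(original_shape):
        if length > current_shape[dim]:
            # The first dimension that shrank is the one split across the TP group.
            partition_dim = dim
            break
    if partition_dim is not None:
        assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
            "The parameter isn't evenly distributed among the tensor parallel group"
    return partition_dim


# A weight split along dim 0 across a TP group of size 4:
assert _search_tp_partition_dim_sketch(torch.Size([1024, 4096]), torch.Size([4096, 4096]), 4) == 0
# A replicated bias: no dimension shrank, so None is returned and no TP gather is needed.
assert _search_tp_partition_dim_sketch(torch.Size([4096]), torch.Size([4096]), 4) is None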
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5aff91f03153..1c19071feb67 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -679,7 +679,7 @@ def state_dict_shard(self, gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block, block_size = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -690,7 +690,7 @@ def state_dict_shard(self, for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = sharder.append(prefix + name, buffer) + block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size # save extra states @@ -698,7 +698,7 @@ def state_dict_shard(self, if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block, block_size = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index a2085323f83e..58b0f33ab189 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,7 +10,7 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.tensor.d_tensor import is_distributed_tensor @@ -691,49 +691,17 @@ def state_shard(self, Iterator[OrderedDict]: A generator of state dict shard of optimizer states. """ - current_block = {} - current_block_size = 0 - + sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) - ret_block = None - ret_block_size = 0 - - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue - - # If the states are stored as DTensors, mark isDTensor as true. 
- if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - - if not isDTensor: - - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - - current_block[param_id] = state - current_block_size += state_size - - if ret_block != None: - yield ret_block, ret_block_size + block, block_size = sharder.append_optim_state(param_id, state) + if block is not None: + yield block, block_size - yield current_block, current_block_size + yield sharder.current_block, sharder.current_block_size class GeminiAdamOptimizer(ZeroOptimizer): diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 67d73c31f6e0..e43908e0c651 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( + assert_close_loose, check_state_dict_equal, clear_cache_before_run, parameterize, @@ -19,34 +20,34 @@ from tests.kit.model_zoo import model_zoo +# TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() @parameterize('shard', [True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', -}, { 'tp_size': 4, 'pp_size': 1, 'precision': 'fp32', }, { 'tp_size': 2, - 'pp_size': 1, - 'precision': 'fp32', + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp16', + 'initial_scale': 1 }, { 'tp_size': 2, 'pp_size': 1, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): @@ -61,46 +62,91 @@ def _criterion(outputs, inputs): loss = criterion(outputs) return loss + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + model = model_fn().cuda() optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - new_model = model_fn().cuda() - new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) - data_iter = iter([data]) - output = booster.execute_pipeline(data_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + 
optimizer, + return_loss=True, + return_outputs=False) else: - data = {k: v.cuda() for k, v in data.items()} - output = model(**data) + output = model(**_preprocess_data(data)) loss = criterion(output) optimizer.backward(loss) optimizer.step() with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" - # optimizer_ckpt_path = f"{tempdir}/optimizer" + optimizer_ckpt_path = f"{tempdir}/optimizer" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) + dist.barrier() + + # Check whether the loaded model & optimizer works smoothly. + model.train() + new_model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False) + else: + old_model_loss = criterion(model(**_preprocess_data(data))) + optimizer.backward(old_model_loss) + new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_optimizer.backward(new_model_loss) + + optimizer.step() + new_optimizer.step() + + # Check updated weights. + stage_manager = booster.plugin.stage_manager + + if stage_manager is None or stage_manager.is_first_stage(): + assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) + assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, + new_model.unwrap().h[0].mlp.c_fc.weight.data, + atol=5e-3, + rtol=5e-3) + dist.barrier() Randomizer.reset_index() clear_layout_converter() From 38ccb8b1a321fa70926236a22cfd7911a993b53e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 1 Sep 2023 17:40:01 +0800 Subject: [PATCH 25/33] [shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575) * hybrid plugin support huggingface from_pretrained * add huggingface compatibility tests * add folder cleaning * fix bugs --- .github/workflows/build_on_pr.yml | 2 +- .../booster/plugin/hybrid_parallel_plugin.py | 4 +- .../hybrid_parallel_checkpoint_io.py | 19 ++- colossalai/checkpoint_io/utils.py | 81 ++++++++++- .../test_hybrid_huggingface_compatibility.py | 129 ++++++++++++++++++ 5 files changed, 218 insertions(+), 17 deletions(-) create mode 100644 tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 4c7e08e5799e..3f91dc33a660 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -208,7 +208,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. 
--durations=10 tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 277843b66568..eced4fc1a16b 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer): def init_pipeline_optimizer(optim: Optimizer, model: Module): - params = set(model.parameters()) + model_params = set(model.parameters()) new_param_groups = [] for group in optim.param_groups: - params = [p for p in group['params'] if p in params] + params = [p for p in group['params'] if p in model_params] new_param_groups.append({**group, 'params': params}) optim.__setstate__({'param_groups': new_param_groups}) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index c128858b1efe..fef5b0d16d60 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -26,6 +26,7 @@ load_shard_state_dict, load_state_dict_into_model, load_states_into_optimizer, + save_config_file, save_param_groups, save_state_dict_shards, search_tp_partition_dim, @@ -204,6 +205,7 @@ def save_sharded_model(self, if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) if self.verbose: logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -219,9 +221,9 @@ def save_sharded_model(self, Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. - weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") - weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, @@ -229,7 +231,8 @@ def save_sharded_model(self, index_file=index_file, base_filename=weights_name, is_master=control_saving, - use_safetensors=use_safetensors) + use_safetensors=use_safetensors, + use_pp_format=True) if control_saving: assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." index_file.append_meta_data("total_size", total_size) @@ -251,6 +254,7 @@ def save_sharded_model(self, final_index_file.append_weight_map(weight, weight_filename) final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) rmtree(tmp_index_file_folder) if self.verbose: logging.info(f"The model is split into checkpoint shards. 
" @@ -423,15 +427,16 @@ def save_sharded_optimizer(self, Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. - states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, checkpoint=checkpoint, index_file=index_file, base_filename=states_name, - is_master=control_saving) + is_master=control_saving, + use_pp_format=True) if control_saving: assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 0025d07dfc8e..0300e62653eb 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -9,12 +9,12 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch -import torch.distributed as dist import torch.nn as nn -from torch.distributed import ProcessGroup from torch.optim import Optimizer +from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype +from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, @@ -228,7 +228,8 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] index_file: "CheckpointIndexFile", base_filename: str, is_master: bool, - use_safetensors: bool = False) -> int: + use_safetensors: bool = False, + use_pp_format: bool = False) -> int: ''' Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Args: @@ -236,14 +237,16 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] checkpoint (str): The path of checkpoint directory as string. index_file (CheckpointIndexFile): The index file object to be updated. base_filename (str): Decides the prefix of filenames of shards. - is_master (bool): Whether current rank is master. - use_safetensors (bool): Whether to use safetensors to save checkpoint. + is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. Returns: int: the total size of shards ''' total_size = 0 + shard_filenames = [] for idx, shard_pair in enumerate(sharded_state_dict): shard, current_size = shard_pair if not is_master: @@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] # Only save on master rank. save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + shard_filenames.append(shard_file) del shard + # Clean folder, deleted unneeded files. 
+ clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) + return total_size @@ -335,6 +342,66 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None: torch.save(param_groups, group_file_path) +def clean_folder(checkpoint_path: str, + weights_name: str, + shard_filenames: List[str], + is_master: bool = True, + use_pp_format: bool = False): + """ + Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. + + Args: + checkpoint_path (str): Path to the checkpoint directory. + weights_name (str): Decides the prefix of filenames of weight shards. + shard_filenames (List[str]): The list of saved shard filenames which should not be removed. + is_master (bool, optional): Whether current rank is main process. Defaults to True. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + """ + if is_master: + for filename in os.listdir(checkpoint_path): + full_filename = os.path.join(checkpoint_path, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + if not use_pp_format: + reg = re.compile(r"(.*?)-\d{5}") + else: + # When this checkpoint is created by pipeline parallel process, the pattern is a little different. + reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") + if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) + and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None): + os.remove(full_filename) + + +def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True): + """ + Save config.json/generation_config.json if model is a Huggingface pretrained model. + This method can only be called when a model is saved in a sharded way. + + Args: + model (nn.Module): The model whose config should be saved if it's a huggingface model. + checkpoint_path (str): Path to the checkpoint directory. + is_master (bool): Whether current rank is main process. + """ + if not isinstance(model, PreTrainedModel): + return + + model = unwrap_huggingface_model(model) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + dtype = get_parameter_dtype(model) + model.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model.config.architectures = [model.__class__.__name__] + + # Save the config + if is_master: + model.config.save_pretrained(checkpoint_path) + if model.can_generate(): + model.generation_config.save_pretrained(checkpoint_path) + + def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: """ Save distributed tensor to checkpoint. 
This checkpoint will be a dictionary which contains @@ -709,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int): get shard file name """ shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") - shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") + shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") return shard_file diff --git a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py new file mode 100644 index 000000000000..df907605d869 --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py @@ -0,0 +1,129 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +def exam_from_pretrained(model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + test_config, + shard=True, + size_per_shard=32): + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + + model = model_fn() + optimizer = Adam((model.parameters()), lr=0.001) + criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + else: + output = model(**_preprocess_data(data)) + loss = criterion(output) + optimizer.backward(loss) + + optimizer.step() + + with shared_tempdir() as tempdir: + + model_ckpt_path = f"{tempdir}/model" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path) + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@clear_cache_before_run() +@parameterize('test_config', [{ + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 2, + 'pp_size': 1, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 +}]) +def run_test(test_config): + sub_model_zoo = 
model_zoo.get_sub_registry('transformers_gpt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + clear_layout_converter() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_test() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_huggingface_compatibility(world_size): + spawn(run_dist, world_size) From 508ca36fe37a8d9434647d224757e06833ed6557 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 1 Sep 2023 21:45:14 +0800 Subject: [PATCH 26/33] [pipeline] 1f1b schedule receive microbatch size (#4589) --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +++++- colossalai/pipeline/schedule/one_f_one_b.py | 27 +++++++++++++++---- .../test_schedule/test_oneF_oneB.py | 2 +- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index eced4fc1a16b..c83e51b26d28 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. 
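For clarity on how the new argument pair documented above is intended to be used, a hedged construction sketch follows (class, argument and import names are taken from elsewhere in this series; the distributed environment is assumed to have been initialized with colossalai.launch, and model/optimizer setup is omitted):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# With pp_size > 1, either num_microbatches or microbatch_size must be given.
# Fixing microbatch_size lets the number of microbatches be derived from the
# global batch size when the schedule loads a batch.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              microbatch_size=1,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)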
@@ -278,6 +281,7 @@ def __init__(self, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, @@ -324,7 +328,9 @@ def __init__(self, assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) + self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 0058873c21ba..11b2655a22c9 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -17,14 +17,26 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None: + def __init__(self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None) -> None: + """1F1B pipeline schedule. + + Args: + stage_manager (PipelineStageManager): Pipeline stage manager + num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. + microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. + """ super().__init__(stage_manager) + assert num_microbatches is not None or microbatch_size is not None, \ + "Either num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches + self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.microbatch_size: Optional[int] = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -39,9 +51,14 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches + if self.num_microbatches is not None: + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + else: + assert self.batch_size % self.microbatch_size == 0, \ + "Batch size should divided by the microbatch size" + self.num_microbatches = self.batch_size // self.microbatch_size def load_micro_batch(self) -> Any: """Load a micro batch from the current batch. 
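The load_batch() logic above boils down to simple integer arithmetic; the following self-contained sketch mirrors its either/or derivation (the helper name is invented for illustration):

from typing import Optional, Tuple


def _derive_microbatch_config(batch_size: int,
                              num_microbatches: Optional[int] = None,
                              microbatch_size: Optional[int] = None) -> Tuple[int, int]:
    # Mirrors the asserts in OneForwardOneBackwardSchedule.load_batch().
    assert num_microbatches is not None or microbatch_size is not None, \
        "Either num_microbatches or microbatch_size should be provided"
    if num_microbatches is not None:
        assert batch_size % num_microbatches == 0, "Batch size should be divisible by the number of microbatches"
        microbatch_size = batch_size // num_microbatches
    else:
        assert batch_size % microbatch_size == 0, "Batch size should be divisible by the microbatch size"
        num_microbatches = batch_size // microbatch_size
    return num_microbatches, microbatch_size


assert _derive_microbatch_config(16, num_microbatches=4) == (4, 4)
assert _derive_microbatch_config(16, microbatch_size=2) == (8, 2)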
diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 542116a1da75..d31eafd70e1a 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -61,7 +61,7 @@ def examine_pp(): DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 pg_mesh = ProcessGroupMesh(1, world_size, 1) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager) + schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS) for idx, (_, sub_model) in enumerate(pp_model.named_children()): if idx % (world_size) == local_rank: From 24c076879558133d66ffcb6111f9bccaa23f6017 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:52:23 +0800 Subject: [PATCH 27/33] [shardformer] Pytree fix (#4533) * pytree test * test bert * test bert * test bert * revise * add register * add register --- colossalai/pipeline/schedule/_utils.py | 62 +++++++++++++++++-- colossalai/pipeline/schedule/one_f_one_b.py | 19 ++++-- colossalai/shardformer/policies/chatglm2.py | 5 ++ tests/test_shardformer/test_model/_utils.py | 11 +--- .../test_model/test_shard_bert.py | 1 + 5 files changed, 81 insertions(+), 17 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 5cd934b76822..583558551b3c 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -1,9 +1,59 @@ -from typing import Any, List, Optional +from collections import OrderedDict +from typing import Any, List, Optional, Tuple import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import ( + SUPPORTED_NODES, + LeafSpec, + TreeSpec, + _is_leaf, + _register_pytree_node, + tree_flatten, + tree_map, + tree_unflatten, +) + + +# this register are for torch under version 1.13.1, maybe removed in the future +def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]: + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]': + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) + + +def tree_map_hf(fn: Any, pytree: Any): + flat_args, spec = tree_flatten_hf(pytree) + return tree_unflatten([fn(i) for i in flat_args], spec) + + +# use this flatten function to handle the ModelingOutput Class instance. +def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values an a TreeSpec that can be used + to reconstruct the pytree. 
+ """ + if isinstance(pytree, OrderedDict): + node_type = OrderedDict + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result: List[Any] = [] + children_specs: List['TreeSpec'] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten_hf(child) + result += flat + children_specs.append(child_spec) + return result, TreeSpec(node_type, context, children_specs) + else: + result, tree_spec = tree_flatten(pytree) + return result, tree_spec def to_device(x: Any, device: Optional[torch.device] = None) -> Any: @@ -104,7 +154,7 @@ def detach(x: Any) -> Any: return x -def merge_batch(data: List[Any]) -> Any: +def merge_batch(data: List[Any], batch_size_dim=0) -> Any: """Merge micro batches into a batch. Args: @@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any: flattened_data = [] tree_spec = None for d in data: - elems, tree_spec = tree_flatten(d) + # elems should be an instance of OrderedDict + elems, tree_spec = tree_flatten_hf(d) flattened_data.append(elems) merged_data = [] + for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs merged_data.append(None) else: - merged_data.append(torch.cat(elem_batch, dim=0)) + merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 11b2655a22c9..ec53a67716c4 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,12 +6,21 @@ from torch.nn import Module from torch.utils._pytree import tree_map -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.cuda import get_current_device -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from ._utils import ( + detach, + get_batch_size, + get_micro_batch, + merge_batch, + model_forward, + retain_grad, + to_device, + tree_map_hf, +) from .base import PipelineSchedule @@ -154,7 +163,7 @@ def forward_step(self, if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: - outputs.append(tree_map(detach, output_obj)) + outputs.append(tree_map_hf(detach, output_obj)) return loss else: return output_obj @@ -302,5 +311,7 @@ def forward_backward_step(self, self.send_backward(input_obj_grad) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0)) return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 5bcbc2acc28e..44898847056a 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -41,6 +41,11 @@ def preprocess(self): new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + if self.pipeline_stage_manager is not None: + # the batch_size_dim is bounded to Model + bsz_dim = 1 + setattr(self.model, 'batch_size_dim', bsz_dim) + 
return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 72bb2b025ba4..f77bf7495808 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state - if stage_manager is None: - sharded_hidden_state = sharded_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(): - pipeline_output = sharded_output['outputs'] - if isinstance(pipeline_output, List): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim) - else: - sharded_hidden_state = pipeline_output.last_hidden_state + sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] + else: + sharded_hidden_state = sharded_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0855e2248710..c779e417052b 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -179,6 +179,7 @@ def run_bert_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() From 0a94fcd3514a6f7d4f287bba614fda3fb12c8802 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 4 Sep 2023 21:46:29 +0800 Subject: [PATCH 28/33] [shardformer] update bert finetune example with HybridParallelPlugin (#4584) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] fix epoch change * [shardformer] broadcast add pp group * [shardformer] fix opt test hanging * fix * test * test * [shardformer] zero1+pp and the corresponding tests (#4517) * pause * finish pp+zero1 * Update test_shard_vit.py * [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516) * fix overlap bug and support bert, add overlap as an option in shardconfig * support overlap for chatglm and bloom * [shardformer] fix emerged bugs after updating transformers (#4526) * test * fix test * fix test * remove print * add fix * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] Add overlap support for gpt2 (#4535) * add overlap support for gpt2 * remove unused code * remove unused code * [shardformer] support pp+tp+zero1 tests (#4531) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] 
pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] fix submodule replacement bug when enabling pp (#4544) * [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540) * implement sharded optimizer saving * add more param info * finish implementation of sharded optimizer saving * fix bugs in optimizer sharded saving * add pp+zero test * param group loading * greedy loading of optimizer * fix bug when loading * implement optimizer sharded saving * add optimizer test & arrange checkpointIO utils * fix gemini sharding state_dict * add verbose option * add loading of master params * fix typehint * fix master/working mapping in fp16 amp * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] fix epoch change * [shardformer] broadcast add pp group * rebase feature/shardformer * update pipeline * [shardformer] fix * [shardformer] fix * [shardformer] bert finetune fix * [shardformer] add all_reduce operation to loss add all_reduce operation to loss * [shardformer] make compatible with pytree. make compatible with pytree. * [shardformer] disable tp disable tp * [shardformer] add 3d plugin to ci test * [shardformer] update num_microbatches to None * [shardformer] update microbatchsize * [shardformer] update assert * update scheduler * update scheduler --------- Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Baizhou Zhang --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- colossalai/pipeline/schedule/one_f_one_b.py | 3 +- examples/language/bert/finetune.py | 163 ++++++++++++++---- examples/language/bert/test_ci.sh | 2 +- 4 files changed, 134 insertions(+), 36 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c83e51b26d28..8ad9b795692a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -325,7 +325,7 @@ def __init__(self, self.schedule = None assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' + assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index ec53a67716c4..5db1c7f30d7f 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -46,6 +46,7 @@ def __init__(self, self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None + self._use_microbatch_size = num_microbatches is None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. 
@@ -60,7 +61,7 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - if self.num_microbatches is not None: + if not self._use_microbatch_size: assert self.batch_size % self.num_microbatches == 0, \ "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index b209ffde85a4..b9a3d57536e4 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -1,12 +1,14 @@ import argparse -from typing import List, Union +from contextlib import nullcontext +from typing import Callable, List, Union import evaluate import torch import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Optimizer +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm from transformers import ( @@ -18,8 +20,9 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -32,14 +35,26 @@ WEIGHT_DECAY = 0.01 WARMUP_FRACTION = 0.1 +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + def move_to_cuda(batch): return {k: v.cuda() for k, v in batch.items()} @torch.no_grad() -def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str, - eval_splits: List[str], coordinator: DistCoordinator): +def evaluate_model( + model: nn.Module, + optimizer, + criterion, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + booster: Booster, + coordinator: DistCoordinator, +): metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -47,23 +62,66 @@ def evaluate_subset(dataloader: DataLoader): accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss.add_(val_loss) - - if num_labels > 1: - preds = torch.argmax(logits, axis=1) - elif num_labels == 1: - preds = logits.squeeze() - labels = batch["labels"] - - metric.add_batch(predictions=preds, references=labels) + batch_size = batch["input_ids"].shape[0] + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + pg_mesh = booster.plugin.pg_mesh + pp_group = booster.plugin.pp_group + current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) + current_rank = dist.get_rank() + #TODO pass dataloader to execute_pipeline directly + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + + if booster.plugin.stage_manager.is_last_stage(): + val_loss = outputs["loss"] + + logits = outputs["outputs"]["logits"] + + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif 
num_labels == 1: + preds = logits.squeeze() + + dist.broadcast(preds, src=current_rank, group=pp_group) + dist.broadcast(val_loss, src=current_rank, group=pp_group) + + metric.add_batch(predictions=preds, references=labels) + elif current_rank in current_pp_group_ranks: + val_loss = torch.empty((1,), device=get_current_device()) + preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device()) + + dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group) + dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group) + + accum_loss.add_(val_loss) + metric.add_batch(predictions=preds, references=labels) + + else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + metric.add_batch(predictions=preds, references=labels) results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) - if coordinator.is_master(): + if coordinator.is_master() and results is not None: results['loss'] = accum_loss.item() / coordinator.world_size + return results if isinstance(test_dataloader, DataLoader): @@ -77,25 +135,43 @@ def evaluate_subset(dataloader: DataLoader): return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + is_pp_last_stage = hasattr( + booster.plugin, + "stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage() + with tqdm(train_dataloader, + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: for batch in pbar: # Forward pass batch = move_to_cuda(batch) - outputs = model(**batch) - loss = outputs[0] + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + #TODO pass train_dataloader to execute_pipeline directly + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if booster.plugin.stage_manager.is_last_stage(): + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + else: + outputs = model(**batch) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({'loss': loss.item()}) - # Backward and optimize - booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() lr_scheduler.step() - # Print log info - pbar.set_postfix({'loss': loss.item()}) - def main(): # ============================== @@ -107,7 +183,7 @@ def main(): '--plugin', type=str, default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], help="plugin to use") parser.add_argument( "--model_type", @@ -116,6 +192,7 @@ def main(): help="bert or albert", ) parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. 
Raise exception if not reached") + parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() if args.model_type == 'bert': @@ -145,6 +222,17 @@ def main(): plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) booster = Booster(plugin=plugin, **booster_kwargs) @@ -165,8 +253,9 @@ def main(): # bert pretrained model cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + if model_name == "bert-base-uncased": - model = BertForSequenceClassification.from_pretrained(model_name, config=cfg) + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() elif model_name == "albert-xxlarge-v2": model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) else: @@ -196,19 +285,27 @@ def main(): num_training_steps=total_steps, ) + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) # ============================== # Train model # ============================== for epoch in range(NUM_EPOCHS): - train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task, + data_builder.eval_splits, booster, coordinator) if coordinator.is_master(): print(results) diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh index 7fc6daabb2f3..394ff831b855 100755 --- a/examples/language/bert/test_ci.sh +++ b/examples/language/bert/test_ci.sh @@ -3,6 +3,6 @@ set -xe pip install -r requirements.txt -for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" done From e79b1e80e25a14c345a2702995b38e418d26c12a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 4 Sep 2023 23:25:01 +0800 Subject: [PATCH 29/33] [checkpointio] support huggingface from_pretrained for all plugins (#4606) --- colossalai/booster/plugin/gemini_plugin.py | 2 + .../checkpoint_io/general_checkpoint_io.py | 2 + .../test_hybrid_huggingface_compatibility.py | 129 ------------------ .../test_plugins_huggingface_compatibility.py | 83 +++++++++++ 4 files changed, 87 insertions(+), 129 deletions(-) delete mode 100644 tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py create mode 100644 
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 0f5ba6e9a6da..8489a8f29686 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -18,6 +18,7 @@ get_optimizer_base_filenames, get_shard_filename, load_shard_state_dict, + save_config_file, save_state_dict, save_state_dict_shards, ) @@ -111,6 +112,7 @@ def save_sharded_model(self, if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model.module, checkpoint_path) logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 83e4bdcc863b..09362d145af2 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -23,6 +23,7 @@ load_state_dict, load_state_dict_into_model, load_states_into_optimizer, + save_config_file, save_param_groups, save_state_dict, save_state_dict_shards, @@ -185,6 +186,7 @@ def save_sharded_model(self, index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint_path, is_master=True) logging.info(f"The model is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py deleted file mode 100644 index df907605d869..000000000000 --- a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from torch.optim import Adam -from utils import shared_tempdir - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor.api import clear_layout_converter -from colossalai.testing import ( - check_state_dict_equal, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo - - -def exam_from_pretrained(model_fn, - data_gen_fn, - output_transform_fn, - loss_fn, - test_config, - shard=True, - size_per_shard=32): - - def _criterion(outputs, inputs): - outputs = output_transform_fn(outputs) - loss = criterion(outputs) - return loss - - def _preprocess_data(data): - if booster.plugin.stage_manager is not None: - for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) - return iter([data]) - else: - return {k: v.cuda() for k, v in data.items()} - - model = model_fn() - optimizer = Adam((model.parameters()), lr=0.001) - criterion = loss_fn - plugin = HybridParallelPlugin(**test_config) - booster = Booster(plugin=plugin) - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - data = data_gen_fn() - model.train() - if booster.plugin.stage_manager is not None: - booster.execute_pipeline(_preprocess_data(data), - model, - 
_criterion, - optimizer, - return_loss=True, - return_outputs=False) - else: - output = model(**_preprocess_data(data)) - loss = criterion(output) - optimizer.backward(loss) - - optimizer.step() - - with shared_tempdir() as tempdir: - - model_ckpt_path = f"{tempdir}/model" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - dist.barrier() - - new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path) - new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - - check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) - - Randomizer.reset_index() - torch.cuda.empty_cache() - - -@clear_cache_before_run() -@parameterize('test_config', [{ - 'tp_size': 4, - 'pp_size': 1, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 2, - 'pp_size': 1, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) -def run_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) - clear_layout_converter() - torch.cuda.empty_cache() - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_test() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_huggingface_compatibility(world_size): - spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py new file mode 100644 index 000000000000..3f3b0392ab5c --- /dev/null +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -0,0 +1,83 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize('model_name', ['transformers_gpt']) +@parameterize('plugin_type', ['ddp', 'zero', 'gemini']) +def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = loss_fn + + if plugin_type == 'ddp': + plugin = TorchDDPPlugin() + elif plugin_type == 'zero': + plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) + elif plugin_type == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32) + else: + raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") + + booster = Booster(plugin=plugin) + + model = model_fn().cuda() + 
model_huggingface_cls = model.__class__ + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + + model_ckpt_path = f"{tempdir}/model" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model_huggingface_cls.from_pretrained(model_ckpt_path) + new_model = new_model.cuda() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + if plugin_type == 'gemini': + check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), + new_model.unwrap().state_dict(only_rank_0=False), False) + else: + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + dist.barrier() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_from_pretrained() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_huggingface_compatibility(world_size): + spawn(run_dist, world_size) From 86d22581e42b350fbe9c5a1f7bc45f7487620214 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:52:23 +0800 Subject: [PATCH 30/33] [shardformer] Add overlap optional for HybridParallelPlugin (#4615) * add optional overlap for plugin * remove fixed todo --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 +++- colossalai/shardformer/layer/_operation.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8ad9b795692a..d33e3485c39c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -280,6 +280,7 @@ def __init__(self, enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, initial_scale: float = 2**16, @@ -341,7 +342,8 @@ def __init__(self, enable_fused_normalization=self.enable_fused_normalization, enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism) + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f45ccc64bae5..45b305733813 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -180,7 +180,6 @@ def backward(ctx, grad_output): overlap = ctx.overlap if not overlap: - # TODO: overlap SP input with gradient computation input_parallel = _gather(input_, dim, process_group) total_input = input_parallel @@ -191,7 +190,6 @@ def backward(ctx, grad_output): grad_output 
= grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - # TODO: overlap SP input with gradient computation if ctx.async_grad_reduce_scatter: # Asynchronous reduce-scatter input_list = [ From ec0866804c3f028f73d4e4d0bc1f3309362c4e89 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 5 Sep 2023 13:14:41 +0800 Subject: [PATCH 31/33] [shardformer] update shardformer readme (#4617) [shardformer] update shardformer readme [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 11 ++++++----- examples/language/bert/README.md | 14 ++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 7dc15f0a0635..2e48a79dc1d7 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. -| accuracy | f1 | loss | GPU number | model shard | + +| accuracy | f1 | loss | GPU number | model sharded | | :------: | :-----: | :-----: | :--------: | :---------: | -| 0.82594 | 0.87441 | 0.09913 | 4 | True | -| 0.81884 | 0.87299 | 0.10120 | 2 | True | -| 0.81855 | 0.87124 | 0.10357 | 1 | False | +| 0.84589 | 0.88613 | 0.43414 | 4 | True | +| 0.83594 | 0.88064 | 0.43298 | 1 | False | + Overall, the results demonstrate that using shardformers during model training does not affect the convergence. 
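The sharded convergence row above was produced with the `hybrid_parallel` branch of the finetune example. A condensed sketch of that plugin setup (mirroring the `finetune.py` change earlier in this series; it assumes the process was launched with torchrun):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})    # torchrun provides the required env vars

# Pipeline parallelism (pp_size=2) plus ZeRO stage 1 in fp16; tensor
# parallelism is disabled for this finetuning configuration.
plugin = HybridParallelPlugin(tp_size=1,
                              pp_size=2,
                              num_microbatches=None,
                              microbatch_size=1,
                              enable_all_optimization=True,
                              zero_stage=1,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)
```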
diff --git a/examples/language/bert/README.md b/examples/language/bert/README.md index da38e8375bf0..6601edb7960e 100644 --- a/examples/language/bert/README.md +++ b/examples/language/bert/README.md @@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be bash test_ci.sh ``` -### Results on 2-GPU +### Bert-Finetune Results + +| Plugin | Accuracy | F1-score | GPU number | +| -------------- | -------- | -------- | -------- | +| torch_ddp | 84.4% | 88.6% | 2 | +| torch_ddp_fp16 | 84.7% | 88.8% | 2 | +| gemini | 84.0% | 88.4% | 2 | +| hybrid_parallel | 84.5% | 88.6% | 4 | -| Plugin | Accuracy | F1-score | -| -------------- | -------- | -------- | -| torch_ddp | 84.4% | 88.6% | -| torch_ddp_fp16 | 84.7% | 88.8% | -| gemini | 84.0% | 88.4% | ## Benchmark ``` From e71d2452936372eaca5d300d43a11c35958fc011 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 5 Sep 2023 14:21:31 +0800 Subject: [PATCH 32/33] [test] ignore gpt2 shardformer test (#4619) --- tests/test_shardformer/test_model/test_shard_gpt2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index a4def9e505d8..24f5137ae929 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -102,6 +102,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() +@pytest.mark.skip(reason="This test will hang in CI") @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, From bd18678478e5ecd18a9fa8a70eedea6f1fcdd036 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 5 Sep 2023 16:02:23 +0800 Subject: [PATCH 33/33] [test] fix gemini checkpoint and gpt test (#4620) --- .../test_plugins_huggingface_compatibility.py | 2 +- tests/test_shardformer/test_model/test_shard_gpt2.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index 3f3b0392ab5c..bd041a5e2fd3 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -32,7 +32,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per elif plugin_type == 'zero': plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) elif plugin_type == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32) + plugin = GeminiPlugin(precision="fp16", initial_scale=32) else: raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 24f5137ae929..768063e537c7 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -102,7 +102,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@pytest.mark.skip(reason="This test will hang in CI") @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, @@ -220,7 +219,7 @@ def check_gpt2_3d(rank, world_size, port): run_gpt2_3d_test() - +@pytest.mark.skip(reason="This test will hang in CI") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()
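The round trip exercised by the new `test_plugins_huggingface_compatibility.py` comes down to saving through the booster and reloading with HuggingFace's `from_pretrained`. A condensed sketch of that flow (the model choice, checkpoint path, and elided training step are illustrative assumptions, not the literal test code):

```python
import torch.distributed as dist
from transformers import GPT2LMHeadModel

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})    # assumes a torchrun launch
booster = Booster(plugin=TorchDDPPlugin())

model = GPT2LMHeadModel.from_pretrained('gpt2').cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... run one or more training steps via booster.backward / optimizer.step ...

# Save in a HuggingFace-compatible sharded format, then reload from disk.
booster.save_model(model, './gpt2_ckpt', shard=True, size_per_shard=32)
dist.barrier()
new_model = GPT2LMHeadModel.from_pretrained('./gpt2_ckpt').cuda()
```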