diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index bb36854bd772..d59b68ce4480 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -143,6 +143,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = Linear1D_Col(in_features=in_features, out_features=out_features, bias=bias, @@ -293,6 +301,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = Linear1D_Row(in_features=in_features, out_features=out_features, bias=bias, diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 42417f8bcc43..df942d43ee2d 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -265,6 +265,14 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, out_features=out_features, bias=bias, @@ -420,6 +428,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, out_features=out_features, bias=bias, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index bdf141816737..9afdfff4d71d 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,7 +1,500 @@ -from typing import Optional, Tuple +import random +from typing import List, Optional, Tuple, Union import torch -from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.models.opt.modeling_opt import ( + OPTForCausalLM, + OPTForQuestionAnswering, + OPTForSequenceClassification, + OPTModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class OPTPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of OPT models + under pipeline setting. + ''' + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + from transformers.models.opt.modeling_opt import _make_causal_mask + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + _dtype, + device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, + tgt_len=input_shape[-1]).to(device) + combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + + combined_attention_mask) + + return combined_attention_mask + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def opt_model_forward( + self: OPTModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + ''' + This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward + ''' + + from transformers.modeling_outputs import BaseModelOutputWithPast + from transformers.utils import logging + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + decoder = self.decoder + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + batch_size, seq_length = input_shape + + if inputs_embeds is None: + inputs_embeds = decoder.embed_tokens(input_ids) + + if decoder.project_in is not None: + inputs_embeds = decoder.project_in(inputs_embeds) + device = input_ids.device if input_ids is not None else inputs_embeds.device + _dtype = inputs_embeds.dtype + + else: + if hidden_states is None: + raise ValueError("hidden_states shouln't be None for intermediate stages.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + _dtype = hidden_states.dtype + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + # embed positions + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)") + + causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, + device, past_key_values_length) + + if stage_manager.is_first_stage(): + pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) + hidden_states = inputs_embeds + pos_embeds + + if decoder.gradient_checkpointing and decoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(decoder.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + torch.cuda.set_device(device) + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = decoder.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if decoder.training and (dropout_probability < decoder.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if decoder.gradient_checkpointing and decoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + if decoder.final_layer_norm is not None: + hidden_states = decoder.final_layer_norm(hidden_states) + if decoder.project_out is not None: + hidden_states = decoder.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_causal_lm_forward( + self: OPTForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward. + Please refer to original code of transformers for more details. + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + logits = self.lm_head(outputs[0]).contiguous() + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_sequence_classification_forward( + self: OPTForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward. + Please refer to original code of transformers for more details. + """ + + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_question_answering_forward( + self: OPTForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} def get_opt_flash_attention_forward(): diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 2a041af19be8..eec339c02872 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -122,6 +122,12 @@ class PolicyLocation: PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"), "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"), + + # ChatGLM + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": + PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": + PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 88ecd8565091..ba6036bd0658 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,32 +1,14 @@ -import logging -import random from functools import partial -from types import MethodType -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List -import torch import torch.nn as nn from torch import Tensor, nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, -) -from transformers.models.opt.modeling_opt import ( - OPTForCausalLM, - OPTForQuestionAnswering, - OPTForSequenceClassification, - OPTModel, -) - -from colossalai.pipeline.stage_manager import PipelineStageManager + from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .._utils import getattr_, setattr_ +from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func -from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward +from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -228,6 +210,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: num_stages = self.pipeline_stage_manager.num_stages if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] + return [] def postprocess(self): if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: @@ -295,594 +278,3 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: "no shared params in OPTForSequenceClassification" return [] - - -class OPTPipelineForwards: - ''' - This class serves as a micro library for forward function substitution of OPT models - under pipeline setting. - ''' - - @staticmethod - def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - from transformers.models.opt.modeling_opt import _make_causal_mask - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - _dtype, - device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, - tgt_len=input_shape[-1]).to(device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) - - return combined_attention_mask - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - @staticmethod - def opt_model_forward( - self: OPTModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - ''' - This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward - ''' - - from transformers.modeling_outputs import BaseModelOutputWithPast - from transformers.utils import logging - logger = logging.get_logger(__name__) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - decoder = self.decoder - if stage_manager.is_first_stage(): - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - batch_size, seq_length = input_shape - - if inputs_embeds is None: - inputs_embeds = decoder.embed_tokens(input_ids) - - if decoder.project_in is not None: - inputs_embeds = decoder.project_in(inputs_embeds) - device = input_ids.device if input_ids is not None else inputs_embeds.device - _dtype = inputs_embeds.dtype - - else: - if hidden_states is None: - raise ValueError("hidden_states shouln't be None for intermediate stages.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device - _dtype = hidden_states.dtype - - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - # embed positions - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)") - - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, - device, past_key_values_length) - - if stage_manager.is_first_stage(): - pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) - hidden_states = inputs_embeds + pos_embeds - - if decoder.gradient_checkpointing and decoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') - past_key_values = None - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(decoder.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") - - start_idx, end_idx = stage_index[0], stage_index[1] - - torch.cuda.set_device(device) - - for idx in range(start_idx, end_idx): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - decoder_layer = decoder.layers[idx] - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if decoder.training and (dropout_probability < decoder.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if decoder.gradient_checkpointing and decoder.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - causal_attention_mask, - head_mask[idx] if head_mask is not None else None, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - if decoder.final_layer_norm is not None: - hidden_states = decoder.final_layer_norm(hidden_states) - if decoder.project_out is not None: - hidden_states = decoder.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - else: - return {'hidden_states': hidden_states} - - @staticmethod - def opt_for_causal_lm_forward( - self: OPTForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional - tensors are only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, OPTForCausalLM - - >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - from transformers.modeling_outputs import CausalLMOutputWithPast - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = OPTPipelineForwards.opt_model_forward( - self.model, - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - if stage_manager.is_last_stage(): - logits = self.lm_head(outputs[0]).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - @staticmethod - def opt_for_sequence_classification_forward( - self: OPTForSequenceClassification, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - from transformers.modeling_outputs import SequenceClassifierOutputWithPast - from transformers.utils import logging - logger = logging.get_logger(__name__) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0] - - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - logger.warning( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - @staticmethod - def opt_for_question_answering_forward( - self: OPTForQuestionAnswering, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, OPTForQuestionAnswering - >>> import torch - - >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - - >>> # note: we are loading a OPTForQuestionAnswering from the hub here, - >>> # so the head will be randomly initialized, hence the predictions will be random - >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") - - >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - - >>> inputs = tokenizer(question, text, return_tensors="pt") - >>> with torch.no_grad(): - ... outputs = model(**inputs) - - >>> answer_start_index = outputs.start_logits.argmax() - >>> answer_end_index = outputs.end_logits.argmax() - - >>> answer_offset = len(tokenizer(question)[0]) - - >>> predict_answer_tokens = inputs.input_ids[ - ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 - ... ] - >>> predicted = tokenizer.decode(predict_answer_tokens) - >>> predicted - ' a nice puppet' - ```""" - from transformers.modeling_outputs import QuestionAnsweringModelOutput - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + transformer_outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 177edbef8935..2d9c882089cb 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -53,7 +53,8 @@ def data_gen_for_question_answering(): # inputs = tokenizer(question, text, return_tensors="pt") input_ids = torch.tensor( - [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64) + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], + dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) start_positions = torch.tensor([1], dtype=torch.int64) end_positions = torch.tensor([10], dtype=torch.int64) @@ -73,12 +74,13 @@ def data_gen_for_question_answering(): loss_fn_for_classification = lambda x: x.loss loss_fn_for_question_answering = lambda x: x.loss -config = transformers.BloomConfig(n_layer=1, +config = transformers.BloomConfig(n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, - hidden_size=64) + hidden_size=64, + pad_token_id=50256) # register the following models model_zoo.register(name='transformers_bloom', diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 90bb70bc7f79..c6473ee2a025 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -17,14 +17,24 @@ def data_gen(): return dict(input_ids=input_ids, attention_mask=attention_mask) +def data_gen_for_conditional_generation(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data + + # define output transform function output_transform_fn = lambda x: x # define loss function -loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum() -loss_fn = lambda x: x.logits.sum() +loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, + torch.ones_like(x.last_hidden_state)) +loss_fn = lambda x: x.loss -config = ChatGLMConfig(num_layers=1, +config = ChatGLMConfig(num_layers=2, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, @@ -33,7 +43,6 @@ def data_gen(): use_cache=True, torch_dtype=torch.float32) - model_zoo.register(name='transformers_chatglm', model_fn=lambda: ChatGLMModel(config, empty_init=False), data_gen_fn=data_gen, @@ -43,7 +52,7 @@ def data_gen(): model_zoo.register(name="transformers_chatglm_for_conditional_generation", model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index 93a8d6c615d7..a84b8d31c284 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -7,11 +7,7 @@ # Register single-sentence VIT # =============================== -config = transformers.ViTConfig( - num_hidden_layers=4, - # hidden_size=128, - # intermediate_size=256, - num_attention_heads=4) +config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) # define data gen function diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c51df07f6c11..921af2a8b1d0 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c if 'use_lazy_init' in test_config: use_lazy_init = test_config.pop('use_lazy_init') - if use_lazy_init: - ctx = LazyInitContext() - else: - ctx = nullcontext() - - plugin = HybridParallelPlugin(**test_config) - booster = Booster(plugin=plugin) - + ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: - org_model = model_fn().cuda() + org_model = model_fn() sharded_model = copy.deepcopy(org_model) - if use_lazy_init: - org_model = ctx.materialize(org_model) + ctx.materialize(org_model) + org_model = org_model.cuda() org_optimizer = Adam(org_model.parameters(), lr=1e-3) sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) criterion = loss_fn - sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster @@ -142,11 +137,12 @@ def _criterion(outputs, inputs): data = data_gen_fn() sharded_model.train() if booster.plugin.stage_manager is not None: - data = { - k: v.to('cuda').repeat(*([4] + [1] * - (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v - for k, v in data.items() - } + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + data_iter = iter([data]) sharded_output = booster.execute_pipeline(data_iter, sharded_model, @@ -176,7 +172,8 @@ def check_output_hidden_state(org_output: Tensor, sharded_output: Tensor, stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, - rtol: float = 1e-3): + rtol: float = 1e-3, + dim: int = 0): org_hidden_state = org_output.last_hidden_state @@ -184,7 +181,7 @@ def check_output_hidden_state(org_output: Tensor, sharded_hidden_state = sharded_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0) + sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim) assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index e11bcf92ea3c..d5a4ce083e2b 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,57 +3,101 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) - # do backward - org_loss.backward() - shard_loss.backward() + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - assert torch.allclose(org_loss, shard_loss, - atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + + if org_model.__class__.__name__ == 'BloomModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'BloomModel': bloom = org_model - sharded_bloom = sharded_model + sharded_bloom = sharded_model.unwrap() else: bloom = org_model.transformer - sharded_bloom = sharded_model.transformer + sharded_bloom = sharded_model.unwrap().transformer # check grad - col_layer_for_check = ['h[0].self_attention.query_key_value'] - row_layer_for_check = ['h[0].self_attention.dense'] - check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) - - -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -@parameterize('use_lazy_init', [False, True]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused, - use_lazy_init): + row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] + col_layer_for_check = ['h[0].self_attention.dense'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_bloom_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bloom(): - spawn(check_bloom, 2) + spawn(check_bloom, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py deleted file mode 100644 index 6695e8a687bd..000000000000 --- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py +++ /dev/null @@ -1,90 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.shard import ShardConfig -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): - policy = get_autopolicy(model) - policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - policy.set_shard_config(model_config) - layers = policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 0 + 2 - else: - if name == 'transformers_bloom': - assert len(layers) == 1 + 1 - elif name == 'transformers_bloom_for_token_classification': - assert len(layers) == 1 + 3 - else: - assert len(layers) == 1 + 2 - - -def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): - if stage_manager.stage == 0: - x = torch.randint(0, 1000, (1, 3)).cuda() - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (1, 3, 64) - else: - attention_mask = torch.ones((1, 3)).cuda() - hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0].shape[0] == 1 - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_bloom -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - PP_DIM = 0 - PP_SIZE = 2 - pg_mesh = ProcessGroupMesh(PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - check_bloom_model_policy(name, org_model, stage_manager) - check_bloom_model_pipeline_forward(name, sharded_model, stage_manager) - - torch.cuda.empty_cache() - - -def check_bloom(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_bloom_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_bloom(): - spawn(check_bloom, 2) - - -if __name__ == "__main__": - test_bloom() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index c455a99d26ce..69e63ffc854e 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -1,99 +1,126 @@ -import copy -import os - import pytest import torch +from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) - # do backward - org_loss.backward() - shard_loss.backward() + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + if org_model.__class__.__name__ == 'ChatGLMModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'ChatGLMModel': chatglm_model = org_model - shard_chatglm_model = sharded_model + shard_chatglm_model = sharded_model.unwrap() else: chatglm_model = org_model.transformer - shard_chatglm_model = sharded_model.transformer - - # check attention grad - org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad - shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad - shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight + shard_chatglm_model = sharded_model.unwrap().transformer + + # check grad + row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] + col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=0, + verbose=False) + + check_grad(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=1e-4, + rtol=1e-3, + dim=1, + verbose=False) - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" - - # check embedding weights - org_grad = chatglm_model.embedding.word_embeddings.weight.grad - shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad - shard_weight = shard_chatglm_model.embedding.word_embeddings.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + torch.cuda.empty_cache() - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_chatglm_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # create new model - org_model = model_fn().cuda() - - # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) - model_copy = copy.deepcopy(org_model) - shard_former = ShardFormer(shard_config=shard_config) - if name == "transformers_chatglm": - sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy()) - else: - sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()) - sharded_model = sharded_model.cuda() - - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -107,7 +134,7 @@ def check_chatglm(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_chatglm(): - spawn(check_chatglm, 2) + spawn(check_chatglm, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py deleted file mode 100644 index ee474ac7be3b..000000000000 --- a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py +++ /dev/null @@ -1,86 +0,0 @@ -import copy -import os - -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # create new model for test - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids = inputs['input_ids'] - hidden_size = 64 - batch_size, seq_len = input_ids.shape - hidden_state_shape = (seq_len, batch_size, hidden_size) - if name == "transformers_chatglm": - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy()) - if stage_manager.is_last_stage(): - hidden_states = torch.randn(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - outputs = sharded_model(**inputs) - if stage_manager.is_last_stage(): - assert outputs[0].shape == hidden_state_shape - - else: - assert outputs['hidden_states'].shape == hidden_state_shape - - if name == "transformers_chatglm_for_conditional_generation": - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init, - ChatGLMForConditionalGenerationPolicy()) - if stage_manager.is_last_stage(): - hidden_states = torch.randn(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - outputs = sharded_model(**inputs) - if stage_manager.is_last_stage(): - assert outputs[0].shape == (batch_size, seq_len, 65024) - else: - assert outputs['hidden_states'].shape == hidden_state_shape - - torch.cuda.empty_cache() - - -def check_chatglm(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_chatglm_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_chatglm(): - spawn(check_chatglm, 4) - - -if __name__ == "__main__": - test_chatglm() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ead14ab111e6..c5f8d22f18c9 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,69 +2,139 @@ import pytest import torch +from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) - # forward check - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - # run backward - org_loss.backward() - shard_loss.backward() + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + if org_model.__class__.__name__ == 'LlamaModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model - if hasattr(org_model, 'model'): - llama_model = org_model.model - shard_llama_model = sharded_model.model - else: + if org_model.__class__.__name__ == 'LlamaModel': llama_model = org_model - shard_llama_model = sharded_model + shard_llama_model = sharded_model.unwrap() + else: + llama_model = org_model.model + shard_llama_model = sharded_model.unwrap().model # check grad - col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] - row_layer_for_check = ['layers[0].self_attn.o_proj'] - check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False) - check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False) + row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] + col_layer_for_check = ['layers[0].self_attn.o_proj'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(llama_model, + shard_llama_model, + row_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-4, + dim=0, + verbose=False) + check_grad(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-4, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=1e-4, + rtol=1e-3, + dim=1, + verbose=False) + + torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('use_lazy_init', [False, True]) -def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init): +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False +}]) +def run_llama_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_gpt2_llama() + run_llama_test() @pytest.mark.dist diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py deleted file mode 100644 index 6f1f0cb34508..000000000000 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.shard import ShardConfig -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): - policy = get_autopolicy(model) - policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - policy.set_shard_config(model_config) - layers = policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 2 + 1 - else: - if name == "transformers_llama": - assert len(layers) == 2 + 1 - else: - assert len(layers) == 2 + 2 - - -def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): - x = torch.randint(0, 1000, (2, 3)).cuda() - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 128) - else: - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0] is not None - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_llama -def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - PP_DIM = 0 - PP_SIZE = 2 - pg_mesh = ProcessGroupMesh(PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - check_llama_model_policy(name, org_model, stage_manager) - check_llama_model_pipeline_forward(name, sharded_model, stage_manager) - - torch.cuda.empty_cache() - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_llama_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, 2) - - -if __name__ == "__main__": - test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 99a278d4303a..d8fa8104bb07 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -1,64 +1,129 @@ -import copy import os import pytest import torch +from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5) +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - # run backward - org_loss.backward() - shard_loss.backward() + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + if org_model.__class__.__name__ == 'OPTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model - if hasattr(org_model, 'model'): - opt_model = org_model.model - shard_opt_model = sharded_model.model - else: + if org_model.__class__.__name__ == 'OPTModel': opt_model = org_model - shard_opt_model = sharded_model + shard_opt_model = sharded_model.unwrap() + else: + opt_model = org_model.model + shard_opt_model = sharded_model.unwrap().model # check grad - col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] - row_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False) - check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False) - - -@parameterize('use_lazy_init', [False, True]) -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, - enable_jit_fused): + row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] + col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=0, + verbose=False) + check_grad(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=1e-3, + rtol=1e-3, + dim=1, + verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_opt_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py b/tests/test_shardformer/test_model/test_shard_opt_pipeline.py deleted file mode 100644 index 0684418d0a8d..000000000000 --- a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_opt -def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids, _ = inputs['input_ids'], inputs['attention_mask'] - batch_size, seq_len = input_ids.shape - hidden_size = 128 - hidden_state_shape = (batch_size, seq_len, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - - hidden_states = torch.zeros(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - sharded_model.train() - - output = sharded_model(**inputs) - if stage_manager.is_last_stage(): - assert output[0] is not None - else: - assert output['hidden_states'].shape == hidden_state_shape - torch.cuda.empty_cache() - - -def check_opt(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_opt_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_opt(): - spawn(check_opt, 4) - - -if __name__ == "__main__": - test_opt() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index d179c8a8ee32..8a78d7c2b8ce 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -1,60 +1,127 @@ -import os - import pytest import torch import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3) + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - # do backward - org_loss.backward() - shard_loss.backward() + if org_model.__class__.__name__ == 'ViTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'ViTModel': vit_model = org_model - shard_vit_model = sharded_model + shard_vit_model = sharded_model.unwrap() else: vit_model = org_model.vit - shard_vit_model = sharded_model.vit + shard_vit_model = sharded_model.unwrap().vit # check grad - col_layer_for_check = ['encoder.layer[0].attention.attention.query'] - row_layer_for_check = ['encoder.layer[0].attention.output.dense'] - check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) - check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False) + row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] + col_layer_for_check = ['encoder.layer[0].attention.output.dense'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(vit_model, + shard_vit_model, + row_layer_for_check, + tp_group, + atol=1e-5, + rtol=1e-3, + dim=0, + verbose=False) + check_grad(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=1e-5, + rtol=1e-3, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=5e-3, + rtol=1e-3, + dim=1, + verbose=False) + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_vit_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -68,7 +135,7 @@ def check_vit(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): - spawn(check_vit, 2) + spawn(check_vit, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py deleted file mode 100644 index 114992a2a2a5..000000000000 --- a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py +++ /dev/null @@ -1,74 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_vit -def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - pixel_values = inputs['pixel_values'] - batch_size = len(pixel_values) - hidden_size = 768 - hidden_state_shape = (batch_size, 197, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - hidden_states = torch.randn(*hidden_state_shape).cuda() - # inputs['pixel_values'] = None - inputs['hidden_states'] = hidden_states - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - sharded_model.train() - - output = sharded_model(**inputs) - if stage_manager.is_last_stage(): - if name != 'transformers_vit': - assert output.loss is not None - else: - assert output['hidden_states'].shape == hidden_state_shape, \ - f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}' - - torch.cuda.empty_cache() - - -def check_vit(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_vit_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_vit(): - spawn(check_vit, 4) - - -if __name__ == "__main__": - test_vit()