diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6197be9d1c8d..20870a3c23a1 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors use_safetensors (bool): whether to use safetensors to save the checkpoint. """ # Move all tensors in the state_dict to CPU before saving to avoid serialization issues - state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict) + state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict) if use_safetensors: assert is_safetensors_available(), "safetensors is not available." diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py index 7b25f3e7489d..013b0f06185d 100644 --- a/colossalai/inference/modeling/models/glide_llama.py +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -6,11 +6,7 @@ import torch import torch.nn as nn -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -137,6 +133,7 @@ def glide_llama_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -147,57 +144,43 @@ def glide_llama_model_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - position_ids = position_ids.unsqueeze(0) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -212,6 +195,7 @@ def glide_llama_model_forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -230,7 +214,9 @@ def glide_llama_model_forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index c49458dbdf55..aa75bab115a7 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -738,7 +738,10 @@ def gpt2_for_sequence_classification_forward( sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 4f4cec8bc81f..facd2fcafbae 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -32,6 +32,7 @@ def _get_attention_mask( hidden_states: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], attention_mask: Optional[torch.FloatTensor], + use_flash_attention_2: bool = False, ) -> Optional[Union[torch.Tensor, dict]]: batch_size, seq_len = hidden_states.shape[:2] past_key_values_length = 0 @@ -47,7 +48,7 @@ def _get_attention_mask( attention_mask, is_causal=True, ) - elif attention_mask is not None: + elif use_flash_attention_2 and attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") attention_mask = attention_mask.view(batch_size, -1) @@ -162,7 +163,9 @@ def gptj_model_forward( output_shape = input_shape + (hidden_states.size(-1),) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -419,7 +422,10 @@ def gptj_for_sequence_classification_forward( sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( @@ -712,7 +718,9 @@ def forward( hidden_states = self.drop(hidden_states) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) @@ -886,7 +894,9 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) if self.gradient_checkpointing and self.training: if use_cache: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 01d10c8dcf95..f47be48eea35 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,11 +7,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.cache_utils import Cache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -21,7 +17,7 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, - apply_rotary_pos_emb, + StaticCache, repeat_kv, ) from transformers.utils import logging @@ -55,6 +51,7 @@ def llama_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -67,6 +64,11 @@ def llama_model_forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..." + ) + use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -83,14 +85,24 @@ def llama_model_forward( device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape device = hidden_states.device - seq_length_with_past = seq_length - past_key_values_length = 0 + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + + seq_length_with_past = seq_length + past_seen_tokens # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: @@ -103,18 +115,8 @@ def llama_model_forward( logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage @@ -129,28 +131,9 @@ def llama_model_forward( is_causal=True, ) else: - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - hidden_states, - past_key_values_length, - ) + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) - if self.gradient_checkpointing and self.training: + if self.gradient_checkpointing and self.training and use_cache: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -190,6 +173,7 @@ def llama_model_forward( past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -199,6 +183,7 @@ def llama_model_forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -249,6 +234,7 @@ def llama_for_causal_lm_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -306,6 +292,7 @@ def llama_for_causal_lm_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -368,6 +355,7 @@ def llama_for_sequence_classification_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -401,6 +389,7 @@ def llama_for_sequence_classification_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -486,6 +475,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: @@ -520,13 +510,14 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -562,6 +553,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -572,41 +564,40 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + # embed positions hidden_states = inputs_embeds # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, @@ -625,43 +616,38 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -672,7 +658,11 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -700,6 +690,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -744,6 +735,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -789,6 +781,8 @@ def forward( def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -797,6 +791,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is ring @@ -835,18 +830,14 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -854,18 +845,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -903,7 +885,7 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) def forward( - self, + self: LlamaModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -913,6 +895,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -924,63 +907,43 @@ def forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one" + ) - seq_length_with_past = seq_length - past_key_values_length = 0 + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - # modify past_key_values_length when using sequence parallel - past_key_values_length *= sp_size - seq_length_with_past = seq_length_with_past + past_key_values_length + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + if position_ids is None: + position_ids = cache_position.unsqueeze(0) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) - hidden_states = inputs_embeds - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -990,14 +953,12 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions) + return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions) return custom_forward @@ -1013,15 +974,20 @@ def custom_forward(*inputs): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = ( + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, Cache) + else next_decoder_cache + ) if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 5f96ebe3d5cd..310c2d8e233a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -4,7 +4,10 @@ import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -77,7 +80,7 @@ def mistral_model_forward( else: position_ids = position_ids.view(-1, seq_length).long() - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -97,9 +100,18 @@ def mistral_model_forward( is_causal=True, ) else: - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( @@ -462,7 +474,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -481,9 +493,18 @@ def forward( is_causal=True, ) else: - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 6d7df963a3a0..cf925983be4e 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -17,6 +17,7 @@ SequenceClassifierOutput, ) from transformers.models.whisper.modeling_whisper import ( + _HIDDEN_STATES_START_POSITION, WhisperDecoder, WhisperEncoder, WhisperForAudioClassification, @@ -166,6 +167,7 @@ def forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -199,9 +201,13 @@ def forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -599,6 +605,7 @@ def whisper_decoder_forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -716,9 +723,13 @@ def whisper_decoder_forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -841,6 +852,7 @@ def whisper_model_forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -944,6 +956,7 @@ def whisper_model_forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -986,6 +999,7 @@ def whisper_for_conditional_generation_forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1048,6 +1062,7 @@ def whisper_for_conditional_generation_forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1118,6 +1133,12 @@ def whisper_for_audio_classification_forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + + if self.config.use_weighted_layer_sum: + output_hidden_states = True + elif output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # audio_classification only holds encoder @@ -1138,7 +1159,8 @@ def whisper_for_audio_classification_forward( return encoder_outputs if self.config.use_weighted_layer_sum: - hidden_states = torch.stack(encoder_outputs, dim=1) + hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) else: diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 3315eb1e9256..c394d911e289 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -34,15 +34,11 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel - - ATTN_IMPLEMENTATION = { - "eager": GPTJAttention, - } + from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 621982f29058..c5a0277a5783 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -42,11 +42,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: MistralDecoderLayer, MistralFlashAttention2, MistralModel, + MistralSdpaAttention, ) ATTN_IMPLEMENTATION = { "eager": MistralAttention, "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, } policy = {} diff --git a/requirements/requirements.txt b/requirements/requirements.txt index fa88501ef968..27bbc3769448 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,7 +16,7 @@ ray sentencepiece google protobuf -transformers>=4.36.2,<4.40.0 +transformers==4.39.3 peft>=0.7.1 bitsandbytes>=0.39.0 rpyc==6.0.0 diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py index 8237384c03fd..57a82647d49b 100644 --- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py @@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype): torch.manual_seed(10) TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN # our crafted op equals to Transformers - x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) - x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) + x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) + x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) + + position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN)) + emb = LlamaRotaryEmbedding(D) - cos, sin = emb(x0, TOTAL_TOKENS) + + cos, sin = emb(x0, position_ids) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin) + cos = cos.reshape((TOTAL_TOKENS, -1)) + sin = sin.reshape((TOTAL_TOKENS, -1)) cos_2 = cos[:, : D // 2] sin_2 = sin[:, : D // 2] - position_ids = torch.arange(TOTAL_TOKENS) - embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) - embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D) + embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2) + embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2) assert torch.allclose(embd_x0, embd_stimulated_x) # create data diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py index 570093693447..78b7ba81c12b 100644 --- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py @@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin): def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout): TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN # our crafted op equals to Transformers - x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) - x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) + x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) + x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) emb = LlamaRotaryEmbedding(D) - cos, sin = emb(x0, TOTAL_TOKENS) + position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN)) + cos, sin = emb(x0, position_ids) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin) + cos = cos.reshape((TOTAL_TOKENS, -1)) + sin = sin.reshape((TOTAL_TOKENS, -1)) cos_2 = cos[:, :32] sin_2 = sin[:, :32] - position_ids = torch.arange(TOTAL_TOKENS) - embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) - embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D) + embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2) + embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2) assert torch.allclose(embd_x0, embd_stimulated_x) # create data