diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 34b0652d1c7674..7e277d480a81b9 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -19,7 +19,7 @@
 # limitations under the License.
 """ PyTorch LLaMA model."""
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -211,6 +211,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

+
+def apply_rotary_pos_emb_unpad(q, key_states, cos, sin, position_ids_unpad):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids_unpad]  # [total_tokens, 1, dim]
+    sin = sin[position_ids_unpad]  # [total_tokens, 1, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    key_states.copy_((key_states * cos) + (rotate_half(key_states) * sin))  # rotate keys in place
+    return q_embed

 class LlamaMLP(nn.Module):
     def __init__(self, config):
@@ -323,9 +333,14 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        flash_kwargs: Optional[Dict] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

+        if flash_kwargs is not None and flash_kwargs["is_unpadded"]:
+            raise ValueError("The non-Flash Attention path does not support unpadded inputs.")
+
         if self.config.pretraining_tp > 1:
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
@@ -478,89 +493,123 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        flash_kwargs: Optional[Dict] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # LlamaFlashAttention attention does not support output_attentions
         output_attentions = False

-        bsz, q_len, _ = hidden_states.size()
+        if not flash_kwargs["is_unpadded"]:
+            bsz, q_len, _ = hidden_states.size()

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)

-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dime x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        # TODO: llama does not have dropout in the config??
-        # It is recommended to use dropout with FA according to the docs
-        # when training.
-        dropout_rate = 0.0  # if not self.training else self.attn_dropout
-
-        # contains at least one padding token
-        if padding_mask is not None:
-            indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-            key_states = index_first_axis(rearrange(key_states, "b s ... -> (b s) ..."), indices_k)
-            value_states = index_first_axis(rearrange(value_states, "b s ... -> (b s) ..."), indices_k)
-
-            # In an ideal world, at least for the path q_len == kv_seq_len and q_len == 1, we should collect the
-            if q_len == kv_seq_len:
-                query_states = index_first_axis(rearrange(query_states, "b s ... -> (b s) ..."), indices_k)
-                cu_seqlens_q = cu_seqlens_k
-                max_seqlen_in_batch_q = max_seqlen_in_batch_k
-                indices_q = indices_k
-            elif q_len == 1:
-                max_seqlen_in_batch_q = 1
-                cu_seqlens_q = torch.arange(
-                    bsz + 1, dtype=torch.int32, device=query_states.device
-                )  # There is a memcpy here, that is very bad.
-                indices_q = cu_seqlens_q[:-1]
-                query_states = query_states.squeeze(1)
-            else:
-                # The -q_len: slice assumes left padding.
-                padding_mask = padding_mask[:, -q_len:]
-                query_states, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_states, padding_mask)
-
-            attn_output_unpad = flash_attn_varlen_func(
+            # Flash attention requires the input to have the shape
+            # batch_size x seq_length x head_dim x hidden_dim
+            # therefore we just need to keep the original shape
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+            kv_seq_len = key_states.shape[-2]
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+            if past_key_value is not None:
+                # reuse k, v, self_attention
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+            past_key_value = (key_states, value_states) if use_cache else None
+
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+
+            # TODO: llama does not have dropout in the config??
+            # It is recommended to use dropout with FA according to the docs
+            # when training.
+            dropout_rate = 0.0  # if not self.training else self.attn_dropout
+
+            # contains at least one padding token
+            if flash_kwargs["masking"]:
+                indices_k = flash_kwargs["indices_k"]
+                cu_seqlens_k = flash_kwargs["cu_seqlens_k"]
+                max_seqlen_in_batch_k = flash_kwargs["max_seqlen_in_batch_k"]
+
+                key_states = index_first_axis(rearrange(key_states, "b s ... -> (b s) ..."), indices_k)
+                value_states = index_first_axis(rearrange(value_states, "b s ... -> (b s) ..."), indices_k)
+
+                # In an ideal world, at least for the path q_len == kv_seq_len and q_len == 1, we should collect the
+                if q_len == kv_seq_len:
+                    query_states = index_first_axis(rearrange(query_states, "b s ... -> (b s) ..."), indices_k)
+                    cu_seqlens_q = cu_seqlens_k
+                    max_seqlen_in_batch_q = max_seqlen_in_batch_k
+                    indices_q = indices_k
+                elif q_len == 1:
+                    max_seqlen_in_batch_q = 1
+                    cu_seqlens_q = flash_kwargs["cu_seqlens_q"]
+                    indices_q = flash_kwargs["indices_q"]
+                    query_states = query_states.squeeze(1)  # [batch_size, 1, num_heads, head_dim] -> [batch_size, num_heads, head_dim]
+                else:
+                    # The -q_len: slice assumes left padding.
+                    padding_mask = padding_mask[:, -q_len:]
+                    query_states, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_states, padding_mask)
+
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=0.0,
+                    softmax_scale=None,
+                    causal=True,
+                )
+
+                attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
+            else:
+                attn_output = flash_attn_func(query_states, key_states, value_states, dropout_rate, causal=True)
+
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        else:
+            # Unpadded path: hidden_states is already packed as (total_tokens, hidden_size).
+            total_tokens, _ = hidden_states.size()
+
+            query_states = self.q_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)
+            key_states = self.k_proj(hidden_states).view(total_tokens, self.num_key_value_heads, self.head_dim)
+            value_states = self.v_proj(hidden_states).view(total_tokens, self.num_key_value_heads, self.head_dim)
+
+            max_seqlen = flash_kwargs["max_seqlen_in_batch_k"]
+
+            cos, sin = self.rotary_emb(value_states, seq_len=max_seqlen)
+
+            # key_states is rotated in place by apply_rotary_pos_emb_unpad
+            query_states = apply_rotary_pos_emb_unpad(query_states, key_states, cos, sin, position_ids)
+
+            # It would be nicer to use flash_attn_kvpacked_func here, with a single nn.Linear for keys/values
+            attn_output = flash_attn_varlen_func(
                 query_states,
                 key_states,
                 value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
+                cu_seqlens_q=flash_kwargs["cu_seqlens_k"],  # queries and keys share the same sequence layout here
+                cu_seqlens_k=flash_kwargs["cu_seqlens_k"],
+                max_seqlen_q=flash_kwargs["max_seqlen_in_batch_k"],
+                max_seqlen_k=flash_kwargs["max_seqlen_in_batch_k"],
                 dropout_p=0.0,
                 softmax_scale=None,
                 causal=True,
             )

-            attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
-        else:
-            attn_output = flash_attn_func(query_states, key_states, value_states, dropout_rate, causal=True)
+            attn_output = attn_output.reshape(-1, self.num_heads * self.head_dim)

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)

         if not output_attentions:
@@ -591,6 +640,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        flash_kwargs: Optional[Dict] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -619,6 +669,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             padding_mask=padding_mask,
+            flash_kwargs=flash_kwargs,
         )
         hidden_states = residual + hidden_states
@@ -770,6 +821,7 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
+        self._flash = getattr(config, "_flash_attn_2_enabled", False)
         # Initialize weights and apply final processing
         self.post_init()
@@ -863,13 +915,39 @@ def forward(
             padding_mask = attention_mask
         else:
             padding_mask = None

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
-
         hidden_states = inputs_embeds
+
+        if not self._flash:
+            flash_kwargs = None
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+        else:
+            flash_kwargs = {}
+            flash_kwargs["masking"] = padding_mask is not None
+
+            if padding_mask is not None:
+                if not use_cache:
+                    hidden_states, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = unpad_input(hidden_states, padding_mask)
+                    position_ids = position_ids.expand(batch_size, seq_length)
+                    position_ids, _, _, _ = unpad_input(position_ids.unsqueeze(-1), padding_mask)
+                    flash_kwargs["is_unpadded"] = True
+                else:
+                    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
+                    flash_kwargs["is_unpadded"] = False
+                    if seq_length == 1:
+                        flash_kwargs["cu_seqlens_q"] = torch.arange(
+                            batch_size + 1, dtype=torch.int32, device=inputs_embeds.device
+                        )  # There is a memcpy here, which is very bad. At least it only happens once.
+                        flash_kwargs["indices_q"] = flash_kwargs["cu_seqlens_q"][:-1]
+                flash_kwargs["indices_k"] = indices_k
+                flash_kwargs["cu_seqlens_k"] = cu_seqlens_k
+                flash_kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
+            else:
+                flash_kwargs["is_unpadded"] = False
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -909,6 +987,7 @@ def custom_forward(*inputs):
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     padding_mask=padding_mask,
+                    flash_kwargs=flash_kwargs,
                 )

             hidden_states = layer_outputs[0]
@@ -921,6 +1000,9 @@ def custom_forward(*inputs):

         hidden_states = self.norm(hidden_states)

+        if self._flash and padding_mask is not None and not use_cache:
+            hidden_states = pad_input(hidden_states, indices_k, batch_size, seq_length)
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
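
As a sanity check on the bookkeeping the padded Flash Attention path relies on, here is a self-contained PyTorch sketch that derives the flattened token indices, cumulative sequence lengths (cu_seqlens) and max sequence length from a 0/1 padding mask, then round-trips a padded tensor through the unpad/pad gather-scatter. The helper name get_unpad_metadata and the toy shapes are illustrative only; the model code uses _get_unpad_data together with flash-attn's unpad_input/pad_input utilities.

import torch

def get_unpad_metadata(padding_mask: torch.Tensor):
    # padding_mask: (batch_size, seq_length) with 1 for real tokens, 0 for padding.
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)               # (batch_size,)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()    # (total_tokens,)
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )  # (batch_size + 1,), starts at 0
    return indices, cu_seqlens, max_seqlen_in_batch

batch_size, seq_length, hidden = 2, 4, 8
mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]])  # left padding
hidden_states = torch.randn(batch_size, seq_length, hidden)

indices, cu_seqlens, max_seqlen = get_unpad_metadata(mask)

# "Unpad": gather only the real tokens into a (total_tokens, hidden) tensor.
unpadded = hidden_states.reshape(batch_size * seq_length, hidden)[indices]

# "Pad": scatter the tokens back into a zero-initialized (batch_size, seq_length, hidden) tensor.
repadded = torch.zeros(batch_size * seq_length, hidden, dtype=hidden_states.dtype)
repadded[indices] = unpadded
repadded = repadded.reshape(batch_size, seq_length, hidden)

assert torch.equal(repadded * mask.unsqueeze(-1), repadded)           # padded slots stay zero
assert torch.equal(unpadded, repadded.reshape(-1, hidden)[indices])   # real tokens survive the round trip
print(cu_seqlens.tolist(), max_seqlen)  # [0, 3, 5] and 3 for this mask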
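
The unpadded rotary helper assumes the rotary cache keeps the usual (1, 1, seq_len, head_dim) layout and that position_ids were gathered with the same unpad indices as the hidden states, ending up as (total_tokens, 1) so the selected cos/sin broadcast over the head dimension. A minimal standalone sketch of that indexing, with a local rotate_half and dummy shapes (all illustrative, not taken from the file):

import torch

def rotate_half(x):
    # Standard Llama helper: rotates half the hidden dims of the input.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

total_tokens, num_heads, head_dim, max_seqlen = 5, 4, 8, 16
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_heads, head_dim)

# Dummy rotary cache with the (1, 1, seq_len, head_dim) layout assumed above.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seqlen).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]  # (1, 1, max_seqlen, head_dim)

# Unpadded position ids: one absolute position per kept token, with a trailing dim of 1
# so the gathered cos/sin have shape (total_tokens, 1, head_dim) and broadcast over heads.
position_ids_unpad = torch.tensor([[0], [1], [2], [0], [1]])

cos_sel = cos.squeeze(1).squeeze(0)[position_ids_unpad]  # (total_tokens, 1, head_dim)
sin_sel = sin.squeeze(1).squeeze(0)[position_ids_unpad]
q_embed = (q * cos_sel) + (rotate_half(q) * sin_sel)
k_embed = (k * cos_sel) + (rotate_half(k) * sin_sel)
print(q_embed.shape, k_embed.shape)  # torch.Size([5, 4, 8]) for both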