From e98a01d3e997c0a15d918833587532e8e8507595 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Wed, 17 Jan 2024 17:43:57 +0000
Subject: [PATCH] ..

---
 llmfoundry/models/layers/attention.py | 67 ++++++++-------------------
 llmfoundry/models/layers/blocks.py    | 17 ++-----
 llmfoundry/models/mpt/modeling_mpt.py | 43 +++++++++++++----
 3 files changed, 55 insertions(+), 72 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 0ccca0774b..de8b06197e 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -228,15 +228,13 @@ def flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
     should_repeat_kv_for_gqa: Optional[bool] = True,
     sliding_window_size: int = -1,
     alibi_slopes: Optional[torch.Tensor] = None,
-    return_indices: bool = False,
-    indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
-                                  torch.Tensor]] = None,
+    flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                  torch.Tensor]]]:
+    del key_padding_mask
     try:
         from flash_attn import bert_padding, flash_attn_interface, index_first_axis  # type: ignore # yapf: disable # isort: skip
     except:
@@ -270,33 +268,20 @@ def flash_attn_fn(
     batch_size, seqlen = query.shape[:2]
 
-    if indices_tuple is None:
-        if attention_mask_in_length is None:
-            if key_padding_mask is None:
-                key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-            query_padding_mask = key_padding_mask[:, -query.size(1):]
-            unpadding_function = bert_padding.unpad_input
-        else:
-            key_padding_mask = attention_mask_in_length
-            query_padding_mask = attention_mask_in_length
-            unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
-
-
-        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-            query, query_padding_mask)
-        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-
-        key_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-            key, key_padding_mask)
-        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
-
-        value_unpad, indices_v, _, _ = unpadding_function(value, key_padding_mask)
-        value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
-    else:
-        indices_q, indices_k, indices_v = indices_tuple
-        query_unpad = index_first_axis(rearrange(query, "b s ... -> (b s) ..."), indices_q)
-        key_unpad = index_first_axis(rearrange(key, "b s ... -> (b s) ..."), indices_k)
-        value_unpad = index_first_axis(rearrange(value, "b s ... -> (b s) ..."), indices_v)
+    indices_q = flash_attn_padding_info['indices_q']
+    indices_k = flash_attn_padding_info['indices_k']
+    indices_v = flash_attn_padding_info['indices_v']
+    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
+    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
+    max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
+    max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
+
+    query_unpad = index_first_axis(rearrange(query, 'b s ... -> (b s) ...'),
+                                   indices_q)
+    key_unpad = index_first_axis(rearrange(key, 'b s ... -> (b s) ...'),
+                                 indices_k)
+    value_unpad = index_first_axis(rearrange(value, 'b s ... -> (b s) ...'),
+                                   indices_v)
 
     if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
             not should_repeat_kv_for_gqa):
@@ -377,8 +362,6 @@ def flash_attn_fn(
     output = bert_padding.pad_input(
         rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
         seqlen)
 
-    if return_indices:
-        return output, None, past_key_value, (indices_q, indices_k, indices_v)
     return output, None, past_key_value
@@ -611,11 +594,8 @@ def forward(
         rotary_emb_w_meta_info: Optional[dict] = None,
         is_causal: bool = True,
         needs_weights: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
-        return_indices: bool = False,
-        indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
-                                      torch.Tensor]] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
@@ -682,15 +662,13 @@ def forward(
         extra_attn_kwargs = {}
         if self.attn_impl == 'flash':
             extra_attn_kwargs = {
-                'attention_mask_in_length': attention_mask_in_length,
                 'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
                 'sliding_window_size': self.sliding_window_size,
                 'alibi_slopes': alibi_slopes,
-                'return_indices': return_indices,
-                'indices_tuple': indices_tuple
+                'flash_attn_padding_info': flash_attn_padding_info,
             }
 
-        attn_fn_output = self.attn_fn(
+        context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
@@ -707,13 +685,6 @@ def forward(
             **extra_attn_kwargs,
         )
 
-        if return_indices:
-            context, attn_weights, past_key_value, indices_tuple = attn_fn_output
-        else:
-            context, attn_weights, past_key_value = attn_fn_output
-
-        if return_indices:
-            return self.out_proj(context), attn_weights, past_key_value, indices_tuple
         return self.out_proj(context), attn_weights, past_key_value
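
The hunks above leave flash_attn_fn as a pure consumer of precomputed padding metadata. The sketch below is not part of the patch; it illustrates the resulting unpad -> flash_attn_varlen_func -> re-pad round trip with made-up tensor names, shapes, and a toy mask, assuming flash-attn 2.x (where bert_padding.unpad_input returns a 4-tuple), a CUDA device, and no KV cache so queries and keys share the same metadata.

import torch
from einops import rearrange
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import flash_attn_varlen_func

bsz, seqlen, n_heads, head_dim = 2, 8, 4, 16
query = torch.randn(bsz, seqlen, n_heads * head_dim, device='cuda', dtype=torch.bfloat16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# 1 = real token, 0 = padding; the second sequence has two padded positions at the end.
attention_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device='cuda')
attention_mask[1, 6:] = False

# The metadata the patch moves out of flash_attn_fn: indices of the real tokens,
# cumulative sequence lengths, and the longest sequence in the batch.
_, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(
    torch.empty(bsz, seqlen, 1, device='cuda'), attention_mask)

# Gather only the real tokens, mirroring the index_first_axis calls in the diff.
q_unpad = bert_padding.index_first_axis(
    rearrange(query, 'b s (h d) -> (b s) h d', h=n_heads), indices)
k_unpad = bert_padding.index_first_axis(
    rearrange(key, 'b s (h d) -> (b s) h d', h=n_heads), indices)
v_unpad = bert_padding.index_first_axis(
    rearrange(value, 'b s (h d) -> (b s) h d', h=n_heads), indices)

# Variable-length attention over the packed (total_tokens, n_heads, head_dim) layout.
out_unpad = flash_attn_varlen_func(q_unpad, k_unpad, v_unpad,
                                   cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                                   max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                                   causal=True)

# Scatter the outputs back to the padded (bsz, seqlen, n_heads * head_dim) layout.
output = bert_padding.pad_input(
    rearrange(out_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, seqlen)
assert output.shape == (bsz, seqlen, n_heads * head_dim)
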
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 140c68e9b6..036a4e7cd2 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -122,15 +122,12 @@ def forward(
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
         output_attentions: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
-        return_indices: bool = False,
-        indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
-                                      torch.Tensor]] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
-        attn_output = self.attn(
+        b, attn_weights, past_key_value = self.attn(
             a,
             past_key_value=past_key_value,
             attn_bias=attn_bias,
@@ -138,15 +135,9 @@ def forward(
             attention_mask=attention_mask,
             is_causal=is_causal,
             needs_weights=output_attentions,
-            attention_mask_in_length=attention_mask_in_length,
             alibi_slopes=alibi_slopes,
-            return_indices=return_indices,
-            indices_tuple=indices_tuple,
+            flash_attn_padding_info=flash_attn_padding_info,
         )
-        if return_indices:
-            b, attn_weights, past_key_value, indices_tuple = attn_output
-        else:
-            b, attn_weights, past_key_value = attn_output
         x = x + self.resid_attn_dropout(b)
         m = x
         if self.norm_2 is not None:
@@ -161,6 +152,4 @@ def forward(
             assert pad_input is not None
             n = pad_input(n, indices, batch_size, seq_len)
         x = x + self.resid_ffn_dropout(n)
-        if return_indices:
-            return x, attn_weights, past_key_value, indices_tuple
         return x, attn_weights, past_key_value
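
The blocks.py changes only thread flash_attn_padding_info through MPTBlock and drop the return_indices plumbing; the next diff builds that dict from one of two mask formats. The following standalone illustration of those two formats is not part of the patch; the values are toys, flash-attn 2.x is assumed, and no attention kernel is invoked, only the index bookkeeping.

import torch
from flash_attn import bert_padding

bsz, seqlen = 2, 6

# Padded batch: a 0/1 key-padding mask, consumed by bert_padding.unpad_input.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
# Only the metadata is kept; the (bsz, seqlen, 1) dummy stands in for the hidden states.
_, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(
    torch.empty(bsz, seqlen, 1), attention_mask)
print(cu_seqlens)  # tensor([ 0,  6, 10], dtype=torch.int32): sequences of length 6 and 4
print(max_seqlen)  # 6

# Packed ("concatenated") sequences: each row lists the lengths of the sequences packed
# into it, consumed by bert_padding.unpad_input_for_concatenated_sequences.
attention_mask_in_length = torch.tensor([[4, 2, 0, 0, 0, 0],
                                         [6, 0, 0, 0, 0, 0]])
_, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input_for_concatenated_sequences(
    torch.empty(bsz, seqlen, 1), attention_mask_in_length)
print(cu_seqlens)  # tensor([ 0,  4,  6, 12], dtype=torch.int32): three packed sequences
print(max_seqlen)  # 6
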
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 31d4a7f8c0..750814e427 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -28,6 +28,7 @@
 if is_flash_v2_installed():
     try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
         from flash_attn.layers.rotary import \
             RotaryEmbedding as DAILRotaryEmbedding
     except Exception as e:
@@ -623,15 +624,43 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
-        return_indices = self.attn_impl == 'flash' # TODO: Make this a config option
-        indices_tuple=None
+        flash_attn_padding_info = {}
+        if self.attn_impl == 'flash':
+            if attention_mask_in_length is None:
+                if attention_mask is None:
+                    past_key_len = past_key_values[0].shape[
+                        1] if past_key_values is not None else 0
+                    attention_mask = torch.ones(
+                        (x.shape[0], past_key_len + x.shape[1]),
+                        dtype=torch.bool)
+                query_padding_mask = attention_mask[:, -x.shape[1]:]
+                unpadding_function = bert_padding.unpad_input
+            else:
+                attention_mask = attention_mask_in_length
+                query_padding_mask = attention_mask_in_length
+                unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+
+            _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+                torch.zeros(1, 1), query_padding_mask)
+            _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+                torch.zeros(1, 1), attention_mask)
+            _, indices_v, _, _ = unpadding_function(torch.zeros(1, 1),
+                                                    attention_mask)
+
+            flash_attn_padding_info['indices_q'] = indices_q
+            flash_attn_padding_info['indices_k'] = indices_k
+            flash_attn_padding_info['indices_v'] = indices_v
+            flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
+            flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
+            flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
+            flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
         for b_idx, block in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = (past_key_values[b_idx]
                               if past_key_values is not None else None)
-            block_output = block(
+            x, attn_weights, present = block(
                 x,
                 past_key_value=past_key_value,
                 attn_bias=attn_bias,
@@ -639,15 +668,9 @@ def forward(
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
-                attention_mask_in_length=attention_mask_in_length,
                 alibi_slopes=alibi_slopes,
-                return_indices=return_indices,
-                indices_tuple=indices_tuple,
+                flash_attn_padding_info=flash_attn_padding_info,
             )
-            if return_indices:
-                x, attn_weights, present, indices_tuple = block_output
-            else:
-                x, attn_weights, present = block_output
 
             if presents is not None:
                 presents += (present,)
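
For reference, the per-forward-pass computation added to modeling_mpt.py can be factored into a helper along the following lines. This is a sketch, not code from the patch: the helper name is made up, flash-attn 2.x is assumed, and the dummy tensors passed to the unpadding functions are sized to match the masks so that the returned indices stay valid when later applied to the real query/key/value tensors.

from typing import Optional

import torch
from flash_attn import bert_padding


def gen_flash_attn_padding_info(  # hypothetical helper name, not from the patch
        bsz: int,
        seq_len: int,
        past_key_len: int,
        device: torch.device,
        attention_mask_in_length: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None) -> dict:
    if attention_mask_in_length is None:
        # Padded-batch case: a 0/1 mask over the full key length (past + current tokens).
        key_padding_mask = attention_mask
        if key_padding_mask is None:
            key_padding_mask = torch.ones((bsz, past_key_len + seq_len),
                                          dtype=torch.bool,
                                          device=device)
        query_padding_mask = key_padding_mask[:, -seq_len:]
        unpadding_function = bert_padding.unpad_input
    else:
        # Packed-sequences case: rows of per-sequence lengths.
        key_padding_mask = attention_mask_in_length
        query_padding_mask = attention_mask_in_length
        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

    # The unpadded dummies are discarded; only indices, cu_seqlens and max_seqlen are kept.
    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
        torch.empty(bsz, seq_len, 1, device=device), query_padding_mask)
    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
        torch.empty(bsz, past_key_len + seq_len, 1, device=device),
        key_padding_mask)
    _, indices_v, _, _ = unpadding_function(
        torch.empty(bsz, past_key_len + seq_len, 1, device=device),
        key_padding_mask)

    return {
        'indices_q': indices_q,
        'indices_k': indices_k,
        'indices_v': indices_v,
        'cu_seqlens_q': cu_seqlens_q,
        'cu_seqlens_k': cu_seqlens_k,
        'max_seqlen_q': max_seqlen_q,
        'max_seqlen_k': max_seqlen_k,
    }

In the diff itself the dict is built inline in MPTModel.forward and the same dict instance is handed to every block, so the unpadding bookkeeping runs once per forward pass instead of once per attention layer.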