diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 0fb6c0a042..0ccca0774b 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -232,10 +232,13 @@ def flash_attn_fn(
     should_repeat_kv_for_gqa: Optional[bool] = True,
     sliding_window_size: int = -1,
     alibi_slopes: Optional[torch.Tensor] = None,
+    return_indices: bool = False,
+    indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
+                                  torch.Tensor]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                  torch.Tensor]]]:
     try:
         from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
     except:
         raise RuntimeError(
             'Please install flash-attn==1.0.9 or flash-attn==2.3.6')
@@ -267,26 +270,33 @@ def flash_attn_fn(
 
     batch_size, seqlen = query.shape[:2]
 
-    if attention_mask_in_length is None:
-        if key_padding_mask is None:
-            key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-        query_padding_mask = key_padding_mask[:, -query.size(1):]
-        unpadding_function = bert_padding.unpad_input
-    else:
-        key_padding_mask = attention_mask_in_length
-        query_padding_mask = attention_mask_in_length
-        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+    if indices_tuple is None:
+        if attention_mask_in_length is None:
+            if key_padding_mask is None:
+                key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
+            query_padding_mask = key_padding_mask[:, -query.size(1):]
+            unpadding_function = bert_padding.unpad_input
+        else:
+            key_padding_mask = attention_mask_in_length
+            query_padding_mask = attention_mask_in_length
+            unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
 
-    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-        query, query_padding_mask)
-    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+            query, query_padding_mask)
+        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
 
-    key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
-        key, key_padding_mask)
-    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+        key_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+            key, key_padding_mask)
+        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
 
-    value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
-    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+        value_unpad, indices_v, _, _ = unpadding_function(value, key_padding_mask)
+        value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    else:
+        # Reuse the unpadding indices computed by an earlier block.
+        indices_q, indices_k, indices_v = indices_tuple
+        query_unpad = bert_padding.index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), indices_q)
+        key_unpad = bert_padding.index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), indices_k)
+        value_unpad = bert_padding.index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), indices_v)
 
     if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
             not should_repeat_kv_for_gqa):
@@ -367,6 +377,8 @@ def flash_attn_fn(
     output = bert_padding.pad_input(
         rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
         seqlen)
 
+    if return_indices:
+        return output, None, past_key_value, (indices_q, indices_k, indices_v)
     return output, None, past_key_value
 
@@ -601,6 +613,9 @@ def forward(
         needs_weights: bool = False,
         attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        return_indices: bool = False,
+        indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
+                                      torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
@@ -671,9 +686,11 @@ def forward(
                 'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
                 'sliding_window_size': self.sliding_window_size,
                 'alibi_slopes': alibi_slopes,
+                'return_indices': return_indices,
+                'indices_tuple': indices_tuple,
             }
 
-        context, attn_weights, past_key_value = self.attn_fn(
+        attn_fn_output = self.attn_fn(
             query,
             key,
             value,
@@ -690,6 +707,13 @@ def forward(
             **extra_attn_kwargs,
         )
 
+        if return_indices:
+            context, attn_weights, past_key_value, indices_tuple = attn_fn_output
+        else:
+            context, attn_weights, past_key_value = attn_fn_output
+
+        if return_indices:
+            return self.out_proj(context), attn_weights, past_key_value, indices_tuple
         return self.out_proj(context), attn_weights, past_key_value
 
 
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index e5032998dc..140c68e9b6 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -124,10 +124,13 @@ def forward(
         output_attentions: bool = False,
         attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        return_indices: bool = False,
+        indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
+                                      torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
-        b, attn_weights, past_key_value = self.attn(
+        attn_output = self.attn(
             a,
             past_key_value=past_key_value,
             attn_bias=attn_bias,
@@ -137,7 +140,13 @@ def forward(
             needs_weights=output_attentions,
             attention_mask_in_length=attention_mask_in_length,
             alibi_slopes=alibi_slopes,
+            return_indices=return_indices,
+            indices_tuple=indices_tuple,
         )
+        if return_indices:
+            b, attn_weights, past_key_value, indices_tuple = attn_output
+        else:
+            b, attn_weights, past_key_value = attn_output
         x = x + self.resid_attn_dropout(b)
         m = x
         if self.norm_2 is not None:
@@ -152,4 +161,6 @@ def forward(
             assert pad_input is not None
             n = pad_input(n, indices, batch_size, seq_len)
         x = x + self.resid_ffn_dropout(n)
+        if return_indices:
+            return x, attn_weights, past_key_value, indices_tuple
         return x, attn_weights, past_key_value
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 8b14c72f62..31d4a7f8c0 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -623,13 +623,15 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
+        return_indices = self.attn_impl == 'flash'  # TODO: Make this a config option
+        indices_tuple = None
         for b_idx, block in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = (past_key_values[b_idx]
                               if past_key_values is not None else None)
-            x, attn_weights, present = block(
+            block_output = block(
                 x,
                 past_key_value=past_key_value,
                 attn_bias=attn_bias,
@@ -639,7 +641,13 @@ def forward(
                 output_attentions=bool(output_attentions),
                 attention_mask_in_length=attention_mask_in_length,
                 alibi_slopes=alibi_slopes,
+                return_indices=return_indices,
+                indices_tuple=indices_tuple,
             )
+            if return_indices:
+                x, attn_weights, present, indices_tuple = block_output
+            else:
+                x, attn_weights, present = block_output
             if presents is not None:
                 presents += (present,)
 
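For context on the pattern the patch implements: the first transformer block derives the unpadding metadata from the padding mask, and every later block reuses it instead of recomputing it. Below is a minimal, framework-free sketch of that idea; `unpad_indices` and `unpad_with_cached_indices` are hypothetical stand-ins for flash-attn's `bert_padding.unpad_input` and `index_first_axis`, and a full implementation would also need to carry `cu_seqlens` and `max_seqlen` in the cache, since the varlen kernel consumes those alongside the token indices.

```python
import torch
import torch.nn.functional as F


def unpad_indices(key_padding_mask: torch.Tensor):
    """Derive unpadding metadata from a (batch, seqlen) bool padding mask.

    Mirrors what flash-attn's ``bert_padding.unpad_input`` computes: flat
    indices of non-padding tokens, cumulative sequence lengths, and the
    longest sequence in the batch.
    """
    seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)  # (batch,)
    indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return indices, cu_seqlens, max_seqlen


def unpad_with_cached_indices(x: torch.Tensor, indices: torch.Tensor):
    """Gather only the non-padding tokens using precomputed indices (the role
    ``index_first_axis`` plays in the patch's cached path)."""
    b, s = x.shape[:2]
    return x.reshape(b * s, *x.shape[2:]).index_select(0, indices)


# First block: compute the metadata once from the mask.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
indices, cu_seqlens, max_seqlen = unpad_indices(mask)

# Later blocks: skip the mask bookkeeping and just gather with the cache.
hidden = torch.randn(2, 4, 8)  # (batch, seqlen, d_model)
hidden_unpad = unpad_with_cached_indices(hidden, indices)
assert hidden_unpad.shape == (int(mask.sum()), 8)  # 5 non-padding tokens
```

The patch itself threads the cache through `indices_tuple`: the first flash-attention call returns `(indices_q, indices_k, indices_v)` and `MPTModel.forward` feeds that tuple back into every subsequent block; the sketch only illustrates why the gather can be decoupled from the mask arithmetic.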