From 2b2f3d844711fe9c797ecd6e6a8c0d400ac89cd3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 17 Jan 2024 07:09:30 +0000 Subject: [PATCH 1/5] .. --- llmfoundry/models/layers/attention.py | 62 +++++++++++++++++++-------- llmfoundry/models/layers/blocks.py | 13 +++++- llmfoundry/models/mpt/modeling_mpt.py | 10 ++++- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 0fb6c0a042..0ccca0774b 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -232,10 +232,13 @@ def flash_attn_fn( should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, + return_indices: bool = False, + indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor, + torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: - from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip + from flash_attn import bert_padding, flash_attn_interface, index_first_axis # type: ignore # yapf: disable # isort: skip except: raise RuntimeError( 'Please install flash-attn==1.0.9 or flash-attn==2.3.6') @@ -267,26 +270,33 @@ def flash_attn_fn( batch_size, seqlen = query.shape[:2] - if attention_mask_in_length is None: - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] - unpadding_function = bert_padding.unpad_input - else: - key_padding_mask = attention_mask_in_length - query_padding_mask = attention_mask_in_length - unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + if indices_tuple is None: + if attention_mask_in_length is None: + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1):] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( - query, query_padding_mask) - query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) + + query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( + query, query_padding_mask) + query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( - key, key_padding_mask) - key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) + key_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function( + key, key_padding_mask) + key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) - value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) + value_unpad, indices_v, _, _ = unpadding_function(value, key_padding_mask) + value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) + else: + indices_q, indices_k, indices_v = indices_tuple + query_unpad = index_first_axis(rearrange(query, "b s ... -> (b s) ..."), indices_q) + key_unpad = index_first_axis(rearrange(key, "b s ... -> (b s) ..."), indices_k) + value_unpad = index_first_axis(rearrange(value, "b s ... 
-> (b s) ..."), indices_v) if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( not should_repeat_kv_for_gqa): @@ -367,6 +377,8 @@ def flash_attn_fn( output = bert_padding.pad_input( rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) + if return_indices: + return output, None, past_key_value, (indices_q, indices_k, indices_v) return output, None, past_key_value @@ -601,6 +613,9 @@ def forward( needs_weights: bool = False, attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + return_indices: bool = False, + indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor, + torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -671,9 +686,11 @@ def forward( 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, + 'return_indices': return_indices, + 'indices_tuple': indices_tuple } - context, attn_weights, past_key_value = self.attn_fn( + attn_fn_output = self.attn_fn( query, key, value, @@ -690,6 +707,13 @@ def forward( **extra_attn_kwargs, ) + if return_indices: + context, attn_weights, past_key_value, indices_tuple = attn_fn_output + else: + context, attn_weights, past_key_value = attn_fn_output + + if return_indices: + return self.out_proj(context), attn_weights, past_key_value, indices_tuple return self.out_proj(context), attn_weights, past_key_value diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index e5032998dc..140c68e9b6 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -124,10 +124,13 @@ def forward( output_attentions: bool = False, attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + return_indices: bool = False, + indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor, + torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) - b, attn_weights, past_key_value = self.attn( + attn_output = self.attn( a, past_key_value=past_key_value, attn_bias=attn_bias, @@ -137,7 +140,13 @@ def forward( needs_weights=output_attentions, attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, + return_indices=return_indices, + indices_tuple=indices_tuple, ) + if return_indices: + b, attn_weights, past_key_value, indices_tuple = attn_output + else: + b, attn_weights, past_key_value = attn_output x = x + self.resid_attn_dropout(b) m = x if self.norm_2 is not None: @@ -152,4 +161,6 @@ def forward( assert pad_input is not None n = pad_input(n, indices, batch_size, seq_len) x = x + self.resid_ffn_dropout(n) + if return_indices: + return x, attn_weights, past_key_value, indices_tuple return x, attn_weights, past_key_value diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8b14c72f62..31d4a7f8c0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -623,13 +623,15 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + return_indices = self.attn_impl == 'flash' # TODO: Make this a config option + indices_tuple=None for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + 
(x,) past_key_value = (past_key_values[b_idx] if past_key_values is not None else None) - x, attn_weights, present = block( + block_output = block( x, past_key_value=past_key_value, attn_bias=attn_bias, @@ -639,7 +641,13 @@ def forward( output_attentions=bool(output_attentions), attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, + return_indices=return_indices, + indices_tuple=indices_tuple, ) + if return_indices: + x, attn_weights, present, indices_tuple = block_output + else: + x, attn_weights, present = block_output if presents is not None: presents += (present,) From 25795b5aa6f89b6c3e106cdb9bad2e2080c6128d Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 17 Jan 2024 07:32:43 +0000 Subject: [PATCH 2/5] undoing prev commit --- llmfoundry/models/layers/attention.py | 62 ++++++++------------------- llmfoundry/models/layers/blocks.py | 13 +----- llmfoundry/models/mpt/modeling_mpt.py | 10 +---- 3 files changed, 21 insertions(+), 64 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 0ccca0774b..0fb6c0a042 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -232,13 +232,10 @@ def flash_attn_fn( should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, - return_indices: bool = False, - indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: - from flash_attn import bert_padding, flash_attn_interface, index_first_axis # type: ignore # yapf: disable # isort: skip + from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: raise RuntimeError( 'Please install flash-attn==1.0.9 or flash-attn==2.3.6') @@ -270,33 +267,26 @@ def flash_attn_fn( batch_size, seqlen = query.shape[:2] - if indices_tuple is None: - if attention_mask_in_length is None: - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] - unpadding_function = bert_padding.unpad_input - else: - key_padding_mask = attention_mask_in_length - query_padding_mask = attention_mask_in_length - unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + if attention_mask_in_length is None: + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1):] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences - - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( - query, query_padding_mask) - query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) + query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( + query, query_padding_mask) + query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function( - key, key_padding_mask) - key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) + key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( + key, key_padding_mask) + key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - 
value_unpad, indices_v, _, _ = unpadding_function(value, key_padding_mask) - value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - else: - indices_q, indices_k, indices_v = indices_tuple - query_unpad = index_first_axis(rearrange(query, "b s ... -> (b s) ..."), indices_q) - key_unpad = index_first_axis(rearrange(key, "b s ... -> (b s) ..."), indices_k) - value_unpad = index_first_axis(rearrange(value, "b s ... -> (b s) ..."), indices_v) + value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) + value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( not should_repeat_kv_for_gqa): @@ -377,8 +367,6 @@ def flash_attn_fn( output = bert_padding.pad_input( rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) - if return_indices: - return output, None, past_key_value, (indices_q, indices_k, indices_v) return output, None, past_key_value @@ -613,9 +601,6 @@ def forward( needs_weights: bool = False, attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, - return_indices: bool = False, - indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -686,11 +671,9 @@ def forward( 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, - 'return_indices': return_indices, - 'indices_tuple': indices_tuple } - attn_fn_output = self.attn_fn( + context, attn_weights, past_key_value = self.attn_fn( query, key, value, @@ -707,13 +690,6 @@ def forward( **extra_attn_kwargs, ) - if return_indices: - context, attn_weights, past_key_value, indices_tuple = attn_fn_output - else: - context, attn_weights, past_key_value = attn_fn_output - - if return_indices: - return self.out_proj(context), attn_weights, past_key_value, indices_tuple return self.out_proj(context), attn_weights, past_key_value diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 140c68e9b6..e5032998dc 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -124,13 +124,10 @@ def forward( output_attentions: bool = False, attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, - return_indices: bool = False, - indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) - attn_output = self.attn( + b, attn_weights, past_key_value = self.attn( a, past_key_value=past_key_value, attn_bias=attn_bias, @@ -140,13 +137,7 @@ def forward( needs_weights=output_attentions, attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, - return_indices=return_indices, - indices_tuple=indices_tuple, ) - if return_indices: - b, attn_weights, past_key_value, indices_tuple = attn_output - else: - b, attn_weights, past_key_value = attn_output x = x + self.resid_attn_dropout(b) m = x if self.norm_2 is not None: @@ -161,6 +152,4 @@ def forward( assert pad_input is not None n = pad_input(n, indices, batch_size, seq_len) x = x + self.resid_ffn_dropout(n) - if return_indices: - return x, attn_weights, past_key_value, indices_tuple return x, attn_weights, past_key_value diff --git 
a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 31d4a7f8c0..8b14c72f62 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -623,15 +623,13 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - return_indices = self.attn_impl == 'flash' # TODO: Make this a config option - indices_tuple=None for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) past_key_value = (past_key_values[b_idx] if past_key_values is not None else None) - block_output = block( + x, attn_weights, present = block( x, past_key_value=past_key_value, attn_bias=attn_bias, @@ -641,13 +639,7 @@ def forward( output_attentions=bool(output_attentions), attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, - return_indices=return_indices, - indices_tuple=indices_tuple, ) - if return_indices: - x, attn_weights, present, indices_tuple = block_output - else: - x, attn_weights, present = block_output if presents is not None: presents += (present,) From 9bc62c94229892bc04d2cd8a0cff63f8ab5c7a40 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 5 Feb 2024 02:08:15 +0000 Subject: [PATCH 3/5] fixing the gen_attention_mask_in_length function to handle the case when sequence id is -1 due to attention masking --- llmfoundry/models/mpt/modeling_mpt.py | 5 +++++ tests/models/layers/test_flash_triton_torch.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 2177124740..d3bf25bf82 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -212,6 +212,11 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, raise ValueError( f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).' ) + if attention_mask is not None: + # -1 is used to pad the sequence_id where attention mask is False. We replace those -1 with 0 to prevent + # `torch.nn.functional.one_hot(sequence_id)` in the next line from failing. We apply the attention mask + # again after the one_hot operation. 
+ sequence_id = sequence_id.masked_fill(~attention_mask, 0) attention_mask_in_length = torch.nn.functional.one_hot(sequence_id) if attention_mask is not None: attention_mask_in_length = attention_mask_in_length.masked_fill( diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index d409486cc6..4fdb34439d 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -134,6 +134,9 @@ def test_attn_impl(attn_impl_0: str, # zero out the last third of the attention mask # to simulate padding attention_mask[:, -s // 3:] = 0 + sequence_id = sequence_id.masked_fill( + ~attention_mask, -1 + ) # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249 def gen_bias(attn_impl: str): causal = True From 45d832fa87ef3094269ca0c7f75e43e9a0606f33 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sun, 4 Feb 2024 18:15:23 -0800 Subject: [PATCH 4/5] Update modeling_mpt.py --- llmfoundry/models/mpt/modeling_mpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d3bf25bf82..79dc8c7f25 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -213,9 +213,9 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).' ) if attention_mask is not None: - # -1 is used to pad the sequence_id where attention mask is False. We replace those -1 with 0 to prevent - # `torch.nn.functional.one_hot(sequence_id)` in the next line from failing. We apply the attention mask - # again after the one_hot operation. + # -1 is used to pad the sequence_id where attention mask is False (https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249). + # We replace those -1 with 0 to prevent `torch.nn.functional.one_hot(sequence_id)` in the next line from failing. + # We apply the attention mask again after the one_hot operation. sequence_id = sequence_id.masked_fill(~attention_mask, 0) attention_mask_in_length = torch.nn.functional.one_hot(sequence_id) if attention_mask is not None: From f1ded147d476fb452d4ccc52ad59c7e5745d7c6f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 5 Feb 2024 03:20:33 +0000 Subject: [PATCH 5/5] .. 
--- tests/models/layers/test_flash_triton_torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 4fdb34439d..4e1efa3f34 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -134,9 +134,10 @@ def test_attn_impl(attn_impl_0: str, # zero out the last third of the attention mask # to simulate padding attention_mask[:, -s // 3:] = 0 - sequence_id = sequence_id.masked_fill( - ~attention_mask, -1 - ) # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249 + if sequence_id is not None: + sequence_id = sequence_id.masked_fill( + ~attention_mask, -1 + ) # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249 def gen_bias(attn_impl: str): causal = True
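
Note on patches 1 and 2: the first patch (reverted in the second) threads the
unpadding indices computed in the first attention layer through the new
return_indices / indices_tuple arguments, so later blocks gather the
non-padded rows directly (the patch uses flash-attn's index_first_axis for
this) instead of re-running the unpad step per layer. A rough, pure-PyTorch
sketch of that gather-and-reuse pattern; the helper names below are
illustrative only, not taken from the patch:

    import torch
    from einops import rearrange

    def unpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
        # Flattened positions of non-padded tokens; computed once per forward pass.
        return torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    def gather_unpadded(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # Pure-torch stand-in for flash-attn's index_first_axis: (b, s, ...) -> (nnz, ...).
        return rearrange(x, 'b s ... -> (b s) ...').index_select(0, indices)

    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
    indices = unpad_indices(mask)           # first block: derive indices from the mask
    q = torch.randn(2, 4, 8)                # (batch, seq_len, d_model)
    q_unpad = gather_unpadded(q, indices)   # later blocks: reuse indices -> shape (5, 8)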
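
Note on patches 3 and 4: gen_attention_mask_in_length now tolerates
sequence_id values of -1 (the padding convention used by the packing code):
the -1 entries are zeroed before torch.nn.functional.one_hot, and the
attention mask is re-applied afterwards so those positions contribute
nothing. A standalone approximation of that computation, simplified relative
to the real function (which also handles the case where no attention mask is
given):

    import torch

    def attention_mask_in_length_sketch(sequence_id: torch.Tensor,
                                        attention_mask: torch.Tensor) -> torch.Tensor:
        # sequence_id: (batch, seq_len) int64, with -1 on padded tokens.
        # attention_mask: (batch, seq_len) bool, True where tokens are real.
        S = sequence_id.shape[-1]
        # Zero the -1s so one_hot does not fail on negative ids; the mask below
        # removes their contribution again.
        sequence_id = sequence_id.masked_fill(~attention_mask, 0)
        one_hot = torch.nn.functional.one_hot(sequence_id)
        one_hot = one_hot.masked_fill(~attention_mask.unsqueeze(-1), 0)
        lengths = one_hot.sum(dim=1)  # tokens per packed sub-sequence, per row
        # Pad the last dim back out to seq_len, the layout expected by
        # flash-attn's unpad_input_for_concatenated_sequences.
        return torch.nn.functional.pad(lengths, (0, S - lengths.shape[-1]), value=0)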