Precompute flash attention padding info #880

Merged
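Reconstructed from the diff below: the change computes the flash-attention unpadding metadata (non-padding token indices, cumulative sequence lengths, and max sequence lengths) once per forward pass in `MPTModel.forward` and passes the resulting dict to every block, instead of re-deriving it from the attention mask inside each layer's `flash_attn_fn`. The following is a minimal, self-contained sketch of that pattern in plain PyTorch; `padding_info_from_mask`, the toy mask, and `n_layers` are illustrative stand-ins, not code from the PR.

```python
import torch


def padding_info_from_mask(mask: torch.Tensor) -> dict:
    # The same quantities flash-attn's unpad_input derives from a mask:
    # flat indices of non-padding tokens, cumulative per-sequence lengths,
    # and the longest sequence in the batch.
    seqlens = mask.sum(dim=-1, dtype=torch.int32)
    return {
        'indices': torch.nonzero(mask.flatten(), as_tuple=False).flatten(),
        'cu_seqlens': torch.nn.functional.pad(
            torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)),
        'max_seqlen': int(seqlens.max()),
    }


mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
padding_info = padding_info_from_mask(mask)  # computed once per forward pass...

x = torch.randn(2, 4, 8)  # toy (batch, seq, d_model) activations
n_layers = 4
for _ in range(n_layers):
    # ...and reused by every attention layer: gather the non-padding rows
    # with the precomputed indices instead of calling unpad_input again.
    x_unpad = x.flatten(0, 1)[padding_info['indices']]
    assert x_unpad.shape[0] == int(padding_info['cu_seqlens'][-1])
```

Precomputing once is valid because these quantities depend only on the attention mask (or the packed-sequence mask), which is identical for every layer within a forward pass.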
Changes from 39 commits
Commits (45)

04dd334  Merge pull request #1 from mosaicml/main (ShashankMosaicML, Oct 9, 2023)
87b2fdc  Merge pull request #8 from mosaicml/main (ShashankMosaicML, Oct 27, 2023)
c9a42e4  Merge pull request #12 from mosaicml/main (ShashankMosaicML, Nov 6, 2023)
ddea9ee  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Nov 6, 2023)
0bcd8ee  Merge pull request #13 from mosaicml/main (ShashankMosaicML, Nov 8, 2023)
f209b58  Merge pull request #14 from mosaicml/main (ShashankMosaicML, Nov 14, 2023)
ec4378d  Merge pull request #15 from mosaicml/main (ShashankMosaicML, Nov 15, 2023)
b436706  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Dec 2, 2023)
bcace03  .. (ShashankMosaicML, Dec 8, 2023)
cf4aa58  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Dec 11, 2023)
7c35ce8  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Dec 13, 2023)
0a8ebfb  .. (ShashankMosaicML, Dec 15, 2023)
6f18a33  .. (ShashankMosaicML, Dec 15, 2023)
f42d585  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Dec 16, 2023)
2f3f53c  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Dec 19, 2023)
77b975f  .. (ShashankMosaicML, Dec 20, 2023)
e28cfbe  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Jan 1, 2024)
800c6f8  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Jan 2, 2024)
922ef13  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Jan 3, 2024)
d36f5f7  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Jan 5, 2024)
d524531  Merge branch 'mosaicml:main' into main (ShashankMosaicML, Jan 17, 2024)
2b2f3d8  .. (ShashankMosaicML, Jan 17, 2024)
e98a01d  .. (ShashankMosaicML, Jan 17, 2024)
5a9e1e8  .. (ShashankMosaicML, Jan 17, 2024)
61d8ade  .. (ShashankMosaicML, Jan 17, 2024)
e236305  .. (ShashankMosaicML, Jan 17, 2024)
416525a  .. (ShashankMosaicML, Jan 17, 2024)
77597a1  .. (ShashankMosaicML, Jan 17, 2024)
09d9bdf  .. (ShashankMosaicML, Jan 17, 2024)
0474e05  .. (ShashankMosaicML, Jan 17, 2024)
0f95056  .. (ShashankMosaicML, Jan 17, 2024)
5063149  .. (ShashankMosaicML, Jan 17, 2024)
0f25b73  .. (ShashankMosaicML, Jan 17, 2024)
c3d30f9  .. (ShashankMosaicML, Jan 17, 2024)
34e4a99  .. (ShashankMosaicML, Jan 17, 2024)
03113a9  .. (ShashankMosaicML, Jan 17, 2024)
3351d23  .. (ShashankMosaicML, Jan 17, 2024)
3d8cda8  Update llmfoundry/models/mpt/modeling_mpt.py (ShashankMosaicML, Jan 17, 2024)
b227bcf  dummy data (ShashankMosaicML, Jan 17, 2024)
bd28b43  undoing last commit (ShashankMosaicML, Jan 17, 2024)
d844c5f  .. (ShashankMosaicML, Jan 17, 2024)
293dde2  .. (ShashankMosaicML, Jan 18, 2024)
18bf7ca  Update llmfoundry/models/mpt/modeling_mpt.py (ShashankMosaicML, Jan 18, 2024)
11d3d70  .. (ShashankMosaicML, Jan 18, 2024)
00bc72b  .. (ShashankMosaicML, Jan 18, 2024)
36 changes: 19 additions & 17 deletions llmfoundry/models/layers/attention.py
@@ -228,12 +228,15 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
del key_padding_mask
if flash_attn_padding_info is None:
raise ValueError('flash_attn_padding_info is required for flash attn.')
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
except:
@@ -267,25 +270,24 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
indices_q = flash_attn_padding_info['indices_q']
indices_k = flash_attn_padding_info['indices_k']
indices_v = flash_attn_padding_info['indices_v']
cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
max_seqlen_k = flash_attn_padding_info['max_seqlen_k']

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
query, query_padding_mask)
query_unpad = bert_padding.index_first_axis(
rearrange(query, 'b s ... -> (b s) ...'), indices_q)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
key, key_padding_mask)
key_unpad = bert_padding.index_first_axis(
rearrange(key, 'b s ... -> (b s) ...'), indices_k)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
value_unpad = bert_padding.index_first_axis(
rearrange(value, 'b s ... -> (b s) ...'), indices_v)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
@@ -599,8 +601,8 @@ def forward(
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -667,10 +669,10 @@ def forward(
extra_attn_kwargs = {}
if self.attn_impl == 'flash':
extra_attn_kwargs = {
'attention_mask_in_length': attention_mask_in_length,
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'flash_attn_padding_info': flash_attn_padding_info,
}

context, attn_weights, past_key_value = self.attn_fn(
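A minimal sketch (not part of the diff) of why the precomputed indices are sufficient: flash-attn's `bert_padding.unpad_input` both gathers the non-padding rows and returns the gather metadata, so once `indices_q` is known, each layer can reproduce the unpadded tensor with `index_first_axis`, exactly as the rewritten `flash_attn_fn` above does. Assumes flash-attn is installed; the toy shapes and mask are illustrative.

```python
import torch
from einops import rearrange
from flash_attn import bert_padding  # assumes flash-attn is installed

b, s, d = 2, 4, 8
query = torch.randn(b, s, d)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

# Old path (pre-PR): each flash_attn_fn call unpadded q, k, and v directly
# from the mask via unpad_input. It returns (unpadded, indices, cu_seqlens,
# max_seqlen) in the flash-attn versions this PR targets; indexing the tuple
# keeps the sketch robust if newer releases append extra values.
unpad_out = bert_padding.unpad_input(query, mask)
q_unpad_old, indices_q = unpad_out[0], unpad_out[1]

# New path: the indices are precomputed once (see gen_flash_attn_padding_info
# in modeling_mpt.py below) and each layer only performs the cheap gather.
q_unpad_new = bert_padding.index_first_axis(
    rearrange(query, 'b s ... -> (b s) ...'), indices_q)

assert torch.equal(q_unpad_old, q_unpad_new)
```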
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/blocks.py
@@ -122,8 +122,8 @@ def forward(
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
@@ -135,8 +135,8 @@ def forward(
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
attention_mask_in_length=attention_mask_in_length,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
)
x = x + self.resid_attn_dropout(b)
m = x
90 changes: 73 additions & 17 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -24,15 +24,23 @@
from composer.models import HuggingFaceModel
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.attention import (is_flash_v1_installed,
is_flash_v2_installed)

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
from flash_attn import bert_padding
from flash_attn.layers.rotary import \
RotaryEmbedding as DAILRotaryEmbedding
except Exception as e:
raise e

if is_flash_v1_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
from flash_attn import bert_padding
except Exception as e:
raise e

from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -216,6 +224,45 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
return attention_mask_in_length


def gen_flash_attn_padding_info(
bsz: int,
S: int,
past_key_len: int,
device: torch.device,
attention_mask_in_length: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None):
flash_attn_padding_info = {}
dummy_data = torch.ones((bsz, past_key_len + S),
dtype=torch.bool,
device=device)
if attention_mask_in_length is None:
key_padding_mask = attention_mask
if key_padding_mask is None:
key_padding_mask = dummy_data
query_padding_mask = key_padding_mask[:, -S:]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

_, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
dummy_data[:, :S, None], query_padding_mask)
_, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
dummy_data[:, :, None], key_padding_mask)
_, indices_v, _, _ = unpadding_function(dummy_data[:, :, None],
key_padding_mask)

flash_attn_padding_info['indices_q'] = indices_q
flash_attn_padding_info['indices_k'] = indices_k
flash_attn_padding_info['indices_v'] = indices_v
flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
return flash_attn_padding_info


def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
max_seq_len: int) -> torch.Tensor:
seq_len = sequence_id.shape[-1]
@@ -515,10 +562,12 @@ def forward(
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
elif input_ids is not None:
bsz = input_ids.size(0)
S = input_ids.size(1)
x = self.wte(input_ids)
input_device = input_ids.device
elif inputs_embeds is not None:
bsz = inputs_embeds.size(0)
S = inputs_embeds.size(1)
x = inputs_embeds
input_device = inputs_embeds.device
@@ -530,22 +579,23 @@
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

rotary_emb_w_meta_info = None
if self.learned_pos_emb or self.rope:
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention '
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)

past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention '
+
f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
)
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
past_position = past_key_values[0][0].size(1)
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)

if self.learned_pos_emb or self.rope:
if self.learned_pos_emb and (S + past_position >
self.config.max_seq_len):
raise ValueError(
@@ -623,6 +673,12 @@ def forward(

all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
flash_attn_padding_info = {}
if self.attn_impl == 'flash':
flash_attn_padding_info = gen_flash_attn_padding_info(
bsz, S, past_position, x.device, attention_mask_in_length,
attention_mask)

for b_idx, block in enumerate(self.blocks):
if output_hidden_states:
assert all_hidden_states is not None # pyright
@@ -637,8 +693,8 @@
attention_mask=attention_mask,
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
attention_mask_in_length=attention_mask_in_length,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
)
if presents is not None:
presents += (present,)
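A hedged usage sketch of the new helper, not code from the PR: it assumes llm-foundry at this PR's revision and flash-attn are installed, and the masks and shapes are toy values. `gen_flash_attn_padding_info` is called once per forward pass, and the resulting dict is passed unchanged to every block; during decoding, `past_key_len` widens the key/value metadata while the query metadata covers only the new positions.

```python
import torch

from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

bsz, S = 2, 6
attention_mask = torch.ones(bsz, S, dtype=torch.bool)
attention_mask[1, -2:] = False  # second sequence ends with two padding tokens

# Prefill: no KV cache, so query and key metadata cover the same S positions.
info = gen_flash_attn_padding_info(bsz, S, 0, torch.device('cpu'),
                                   attention_mask_in_length=None,
                                   attention_mask=attention_mask)
# Keys: indices_q, indices_k, indices_v, cu_seqlens_q, cu_seqlens_k,
#       max_seqlen_q, max_seqlen_k
print(info['max_seqlen_q'], info['max_seqlen_k'])  # 6 6
print(info['cu_seqlens_k'])  # tensor([0, 6, 10], dtype=torch.int32)

# Decode step: one new token on top of 6 cached positions. Query metadata
# covers only the new position; key metadata spans past + new positions.
step_mask = torch.ones(bsz, 7, dtype=torch.bool)  # 6 cached + 1 new
step_info = gen_flash_attn_padding_info(bsz, 1, 6, torch.device('cpu'),
                                        attention_mask_in_length=None,
                                        attention_mask=step_mask)
print(step_info['max_seqlen_q'], step_info['max_seqlen_k'])  # 1 7

# MPTModel.forward builds this dict once and hands it to every block:
#   block(x, ..., flash_attn_padding_info=info)
```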