
Commit

..
ShashankMosaicML committed Jan 17, 2024
1 parent 2b2f3d8 commit e98a01d
Showing 3 changed files with 55 additions and 72 deletions.
67 changes: 19 additions & 48 deletions llmfoundry/models/layers/attention.py
@@ -228,15 +228,13 @@ def flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
     should_repeat_kv_for_gqa: Optional[bool] = True,
     sliding_window_size: int = -1,
     alibi_slopes: Optional[torch.Tensor] = None,
-    return_indices: bool = False,
-    indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
-                                  torch.Tensor]] = None,
+    flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                  torch.Tensor]]]:
+    del key_padding_mask
     try:
         from flash_attn import bert_padding, flash_attn_interface, index_first_axis  # type: ignore # yapf: disable # isort: skip
     except:
@@ -270,33 +268,20 @@ def flash_attn_fn(

     batch_size, seqlen = query.shape[:2]

-    if indices_tuple is None:
-        if attention_mask_in_length is None:
-            if key_padding_mask is None:
-                key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-            query_padding_mask = key_padding_mask[:, -query.size(1):]
-            unpadding_function = bert_padding.unpad_input
-        else:
-            key_padding_mask = attention_mask_in_length
-            query_padding_mask = attention_mask_in_length
-            unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
-
-
-        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-            query, query_padding_mask)
-        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-
-        key_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-            key, key_padding_mask)
-        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
-
-        value_unpad, indices_v, _, _ = unpadding_function(value, key_padding_mask)
-        value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
-    else:
-        indices_q, indices_k, indices_v = indices_tuple
-        query_unpad = index_first_axis(rearrange(query, "b s ... -> (b s) ..."), indices_q)
-        key_unpad = index_first_axis(rearrange(key, "b s ... -> (b s) ..."), indices_k)
-        value_unpad = index_first_axis(rearrange(value, "b s ... -> (b s) ..."), indices_v)
+    indices_q = flash_attn_padding_info['indices_q']
+    indices_k = flash_attn_padding_info['indices_k']
+    indices_v = flash_attn_padding_info['indices_v']
+    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
+    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
+    max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
+    max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
+
+    query_unpad = index_first_axis(rearrange(query, 'b s ... -> (b s) ...'),
+                                   indices_q)
+    key_unpad = index_first_axis(rearrange(key, 'b s ... -> (b s) ...'),
+                                 indices_k)
+    value_unpad = index_first_axis(rearrange(value, 'b s ... -> (b s) ...'),
+                                   indices_v)

     if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
             not should_repeat_kv_for_gqa):
@@ -377,8 +362,6 @@ def flash_attn_fn(
     output = bert_padding.pad_input(
         rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
         seqlen)
-    if return_indices:
-        return output, None, past_key_value, (indices_q, indices_k, indices_v)
     return output, None, past_key_value


@@ -611,11 +594,8 @@ def forward(
         rotary_emb_w_meta_info: Optional[dict] = None,
         is_causal: bool = True,
         needs_weights: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
-        return_indices: bool = False,
-        indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
-                                      torch.Tensor]] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
@@ -682,15 +662,13 @@ def forward(
         extra_attn_kwargs = {}
         if self.attn_impl == 'flash':
             extra_attn_kwargs = {
-                'attention_mask_in_length': attention_mask_in_length,
                 'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
                 'sliding_window_size': self.sliding_window_size,
                 'alibi_slopes': alibi_slopes,
-                'return_indices': return_indices,
-                'indices_tuple': indices_tuple
+                'flash_attn_padding_info': flash_attn_padding_info,
             }

-        attn_fn_output = self.attn_fn(
+        context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
@@ -707,13 +685,6 @@ def forward(
             **extra_attn_kwargs,
         )

-        if return_indices:
-            context, attn_weights, past_key_value, indices_tuple = attn_fn_output
-        else:
-            context, attn_weights, past_key_value = attn_fn_output
-
-        if return_indices:
-            return self.out_proj(context), attn_weights, past_key_value, indices_tuple
         return self.out_proj(context), attn_weights, past_key_value
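
Note (illustration, not part of this commit): after this change, flash_attn_fn no longer unpads its inputs itself; it gathers the non-padding rows with the precomputed indices and scatters the attention output back afterwards. The sketch below shows that round trip in plain PyTorch. gather_rows and scatter_rows are hypothetical stand-ins for flash_attn's index_first_axis and bert_padding.pad_input, and indices_q plays the role of flash_attn_padding_info['indices_q'].

import torch
from einops import rearrange

# Plain-PyTorch stand-ins for flash_attn's index_first_axis and bert_padding.pad_input.
def gather_rows(x_flat: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Keep only the rows that correspond to real (non-padding) tokens."""
    return x_flat[indices]

def scatter_rows(x_unpad: torch.Tensor, indices: torch.Tensor, batch: int,
                 seqlen: int) -> torch.Tensor:
    """Scatter unpadded rows back into a zero-padded (batch, seqlen, ...) tensor."""
    out = torch.zeros(batch * seqlen, *x_unpad.shape[1:], dtype=x_unpad.dtype)
    out[indices] = x_unpad
    return out.view(batch, seqlen, *x_unpad.shape[1:])

batch, seqlen, n_heads, head_dim = 2, 4, 3, 8
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
indices_q = torch.nonzero(mask.flatten(), as_tuple=False).flatten()  # positions of real tokens

query = torch.randn(batch, seqlen, n_heads * head_dim)
query_unpad = gather_rows(rearrange(query, 'b s d -> (b s) d'), indices_q)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)  # (5, 3, 8)

# ... the flash attention kernel would run on the unpadded rows here ...
output_unpad = query_unpad  # placeholder for the attention output

output = scatter_rows(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                      indices_q, batch, seqlen)  # back to (2, 4, 24), padding rows zeroed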


17 changes: 3 additions & 14 deletions llmfoundry/models/layers/blocks.py
@@ -122,31 +122,22 @@ def forward(
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
         output_attentions: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
-        return_indices: bool = False,
-        indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
-                                      torch.Tensor]] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
-        attn_output = self.attn(
+        b, attn_weights, past_key_value = self.attn(
             a,
             past_key_value=past_key_value,
             attn_bias=attn_bias,
             rotary_emb_w_meta_info=rotary_emb_w_meta_info,
             attention_mask=attention_mask,
             is_causal=is_causal,
             needs_weights=output_attentions,
-            attention_mask_in_length=attention_mask_in_length,
             alibi_slopes=alibi_slopes,
-            return_indices=return_indices,
-            indices_tuple=indices_tuple,
+            flash_attn_padding_info=flash_attn_padding_info,
         )
-        if return_indices:
-            b, attn_weights, past_key_value, indices_tuple = attn_output
-        else:
-            b, attn_weights, past_key_value = attn_output
         x = x + self.resid_attn_dropout(b)
         m = x
         if self.norm_2 is not None:
@@ -161,6 +152,4 @@ def forward(
             assert pad_input is not None
             n = pad_input(n, indices, batch_size, seq_len)
         x = x + self.resid_ffn_dropout(n)
-        if return_indices:
-            return x, attn_weights, past_key_value, indices_tuple
         return x, attn_weights, past_key_value
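
Note (illustration, not part of this commit): the block-level change is plumbing only. MPTBlock.forward now accepts flash_attn_padding_info, forwards it to its attention layer, and always returns a plain 3-tuple. A minimal sketch of that calling convention, assuming a hypothetical StubBlock and illustrative dictionary contents:

import torch
from torch import nn

class StubBlock(nn.Module):
    """Hypothetical stand-in for MPTBlock; only the calling convention matters here."""

    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, flash_attn_padding_info=None):
        # A real block hands flash_attn_padding_info straight down to its attention layer.
        return self.proj(x), None, None

batch, seqlen, d_model = 2, 4, 16
padding_info = {  # illustrative contents; normally produced once in the model's forward pass
    'indices_q': torch.arange(batch * seqlen),
    'indices_k': torch.arange(batch * seqlen),
    'indices_v': torch.arange(batch * seqlen),
    'cu_seqlens_q': torch.tensor([0, seqlen, 2 * seqlen], dtype=torch.int32),
    'cu_seqlens_k': torch.tensor([0, seqlen, 2 * seqlen], dtype=torch.int32),
    'max_seqlen_q': seqlen,
    'max_seqlen_k': seqlen,
}

x = torch.randn(batch, seqlen, d_model)
for block in nn.ModuleList(StubBlock(d_model) for _ in range(3)):
    # The padding metadata depends only on the mask, so it is computed once per
    # forward pass and reused by every block instead of re-derived per layer.
    x, _, _ = block(x, flash_attn_padding_info=padding_info)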
43 changes: 33 additions & 10 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -28,6 +28,7 @@

 if is_flash_v2_installed():
     try: # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
         from flash_attn.layers.rotary import \
             RotaryEmbedding as DAILRotaryEmbedding
     except Exception as e:
@@ -623,31 +624,53 @@ def forward(

         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        return_indices = self.attn_impl == 'flash' # TODO: Make this a config option
-        indices_tuple=None
+        flash_attn_padding_info = {}
+        if self.attn_impl == 'flash':
+            if attention_mask_in_length is None:
+                if attention_mask is None:
+                    past_key_len = past_key_values[0].shape[
+                        1] if past_key_values is not None else 0
+                    attention_mask = torch.ones(
+                        (x.shape[0], past_key_len + x.shape[1]),
+                        dtype=torch.bool)
+                query_padding_mask = attention_mask[:, -x.shape[1]:]
+                unpadding_function = bert_padding.unpad_input
+            else:
+                attention_mask = attention_mask_in_length
+                query_padding_mask = attention_mask_in_length
+                unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+
+            _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+                torch.zeros(1, 1), query_padding_mask)
+            _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+                torch.zeros(1, 1), attention_mask)
+            _, indices_v, _, _ = unpadding_function(torch.zeros(1, 1),
+                                                    attention_mask)
+
+            flash_attn_padding_info['indices_q'] = indices_q
+            flash_attn_padding_info['indices_k'] = indices_k
+            flash_attn_padding_info['indices_v'] = indices_v
+            flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
+            flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
+            flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
+            flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
         for b_idx, block in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = (past_key_values[b_idx]
                               if past_key_values is not None else None)
-            block_output = block(
+            x, attn_weights, present = block(
                 x,
                 past_key_value=past_key_value,
                 attn_bias=attn_bias,
                 rotary_emb_w_meta_info=rotary_emb_w_meta_info,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
-                attention_mask_in_length=attention_mask_in_length,
                 alibi_slopes=alibi_slopes,
-                return_indices=return_indices,
-                indices_tuple=indices_tuple,
+                flash_attn_padding_info=flash_attn_padding_info,
             )
-            if return_indices:
-                x, attn_weights, present, indices_tuple = block_output
-            else:
-                x, attn_weights, present = block_output
             if presents is not None:
                 presents += (present,)
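
Note (illustration, not part of this commit): a self-contained approximation in plain PyTorch of the metadata the new code obtains from bert_padding.unpad_input. It ignores the KV-cache case, where the key mask also covers past_key_len positions, and the packed-sequence path that uses unpad_input_for_concatenated_sequences; build_flash_attn_padding_info is a hypothetical helper, not a function in this repository.

import torch
import torch.nn.functional as F

def build_flash_attn_padding_info(attention_mask: torch.Tensor) -> dict:
    """Roughly what bert_padding.unpad_input derives from a (batch, seqlen) bool mask."""
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)  # real tokens per row
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return {
        'indices_q': indices, 'indices_k': indices, 'indices_v': indices,
        'cu_seqlens_q': cu_seqlens, 'cu_seqlens_k': cu_seqlens,
        'max_seqlen_q': max_seqlen, 'max_seqlen_k': max_seqlen,
    }

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
info = build_flash_attn_padding_info(mask)
# indices: tensor([0, 1, 2, 4, 5]); cu_seqlens: tensor([0, 3, 5]); max_seqlen: 3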

