Commit 2b2f3d8
ShashankMosaicML committed Jan 17, 2024
1 parent d524531 commit 2b2f3d8
Showing 3 changed files with 64 additions and 21 deletions.
62 changes: 43 additions & 19 deletions llmfoundry/models/layers/attention.py
@@ -232,10 +232,13 @@ def flash_attn_fn(
should_repeat_kv_for_gqa: Optional[bool] = True,
sliding_window_size: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
return_indices: bool = False,
indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
from flash_attn.bert_padding import index_first_axis # type: ignore # yapf: disable # isort: skip
except:
raise RuntimeError(
'Please install flash-attn==1.0.9 or flash-attn==2.3.6')
@@ -267,26 +270,33 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
if indices_tuple is None:
if attention_mask_in_length is None:
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1):]
unpadding_function = bert_padding.unpad_input
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
key_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
key, key_padding_mask)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
value_unpad, indices_v, _, _ = unpadding_function(value, key_padding_mask)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
else:
indices_q, indices_k, indices_v = indices_tuple
query_unpad = index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), indices_q)
key_unpad = index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), indices_k)
value_unpad = index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), indices_v)

if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
not should_repeat_kv_for_gqa):
@@ -367,6 +377,8 @@ def flash_attn_fn(
output = bert_padding.pad_input(
rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
seqlen)
if return_indices:
return output, None, past_key_value, (indices_q, indices_k, indices_v)
return output, None, past_key_value


@@ -601,6 +613,9 @@ def forward(
needs_weights: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
return_indices: bool = False,
indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -671,9 +686,11 @@ def forward(
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'return_indices': return_indices,
'indices_tuple': indices_tuple,
}

context, attn_weights, past_key_value = self.attn_fn(
attn_fn_output = self.attn_fn(
query,
key,
value,
@@ -690,6 +707,13 @@ def forward(
**extra_attn_kwargs,
)

if return_indices:
context, attn_weights, past_key_value, indices_tuple = attn_fn_output
else:
context, attn_weights, past_key_value = attn_fn_output

if return_indices:
return self.out_proj(context), attn_weights, past_key_value, indices_tuple
return self.out_proj(context), attn_weights, past_key_value


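The core change above lets flash_attn_fn skip re-deriving the unpadding metadata on every layer: when indices_tuple is supplied, the query, key, and value rows are gathered directly with index_first_axis instead of re-running the unpadding function, and when return_indices is set, the computed (indices_q, indices_k, indices_v) are handed back to the caller. A minimal sketch of that reuse pattern, assuming flash-attn 2.3.x is installed (the shapes, mask, and variable names below are illustrative, not the function's real interface):

import torch
from einops import rearrange
from flash_attn import bert_padding

batch_size, seqlen, n_heads, head_dim = 2, 6, 4, 8
query = torch.randn(batch_size, seqlen, n_heads * head_dim)
# Mark the last two positions of the second sequence as padding.
key_padding_mask = torch.ones(batch_size, seqlen, dtype=torch.bool)
key_padding_mask[1, -2:] = False

# First layer: unpad once and keep the gather indices.
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
    query, key_padding_mask)

# Later layers: reuse the cached indices with a plain gather instead of
# recomputing the mask-derived metadata.
query_unpad_reused = bert_padding.index_first_axis(
    rearrange(query, 'b s ... -> (b s) ...'), indices_q)
assert torch.equal(query_unpad, query_unpad_reused)

In the diff, the same gather is applied to key and value using indices_k and indices_v from the cached tuple.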
13 changes: 12 additions & 1 deletion llmfoundry/models/layers/blocks.py
@@ -124,10 +124,13 @@ def forward(
output_attentions: bool = False,
attention_mask_in_length: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
return_indices: bool = False,
indices_tuple: Optional[tuple[torch.Tensor, torch.Tensor,
torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
b, attn_weights, past_key_value = self.attn(
attn_output = self.attn(
a,
past_key_value=past_key_value,
attn_bias=attn_bias,
@@ -137,7 +140,13 @@ def forward(
needs_weights=output_attentions,
attention_mask_in_length=attention_mask_in_length,
alibi_slopes=alibi_slopes,
return_indices=return_indices,
indices_tuple=indices_tuple,
)
if return_indices:
b, attn_weights, past_key_value, indices_tuple = attn_output
else:
b, attn_weights, past_key_value = attn_output
x = x + self.resid_attn_dropout(b)
m = x
if self.norm_2 is not None:
@@ -152,4 +161,6 @@ def forward(
assert pad_input is not None
n = pad_input(n, indices, batch_size, seq_len)
x = x + self.resid_ffn_dropout(n)
if return_indices:
return x, attn_weights, past_key_value, indices_tuple
return x, attn_weights, past_key_value
10 changes: 9 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -623,13 +623,15 @@ def forward(

all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
return_indices = self.attn_impl == 'flash'  # TODO: Make this a config option
indices_tuple = None
for b_idx, block in enumerate(self.blocks):
if output_hidden_states:
assert all_hidden_states is not None # pyright
all_hidden_states = all_hidden_states + (x,)
past_key_value = (past_key_values[b_idx]
if past_key_values is not None else None)
x, attn_weights, present = block(
block_output = block(
x,
past_key_value=past_key_value,
attn_bias=attn_bias,
@@ -639,7 +641,13 @@ def forward(
output_attentions=bool(output_attentions),
attention_mask_in_length=attention_mask_in_length,
alibi_slopes=alibi_slopes,
return_indices=return_indices,
indices_tuple=indices_tuple,
)
if return_indices:
x, attn_weights, present, indices_tuple = block_output
else:
x, attn_weights, present = block_output
if presents is not None:
presents += (present,)

Expand Down
