Commit 1a4123a
Merge branch 'main' into soft_cap_attn
ShashankMosaicML authored Sep 22, 2024
2 parents 9a68b8f + d7c7822 commit 1a4123a
Showing 2 changed files with 17 additions and 5 deletions.
3 changes: 3 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -670,6 +670,9 @@ def get_qkv(
                     'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
                 )
             key, value = prev_layer_key_value
+            if self.attn_impl == 'torch':
+                key = rearrange(key, 'b h d s -> b s (h d)')
+                value = rearrange(value, 'b h s d -> b s (h d)')

             query = self.Wq(x)
             if self.clip_qkv:
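
Aside (not part of the diff): a minimal sketch of the layout conversion the new branch performs when a previous layer's kv cache is reused under the torch implementation. It assumes the torch path caches key pre-transposed as (b, h, d, s) and value as (b, h, s, d), while get_qkv otherwise works with (b, s, h * d) tensors; the shapes below are illustrative, not taken from the library.

    import torch
    from einops import rearrange

    b, h, s, d = 2, 4, 8, 16
    cached_key = torch.randn(b, h, d, s)    # assumed torch-impl cache layout for key
    cached_value = torch.randn(b, h, s, d)  # assumed torch-impl cache layout for value

    # Undo the cached layouts back to the (b, s, h * d) layout get_qkv expects.
    key = rearrange(cached_key, 'b h d s -> b s (h d)')
    value = rearrange(cached_value, 'b h s d -> b s (h d)')
    assert key.shape == value.shape == (b, s, h * d)
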
19 changes: 14 additions & 5 deletions tests/models/layers/test_flash_torch.py
@@ -188,7 +188,7 @@ def gen_bias(attn_impl: str):
             alibi=alibi,
             alibi_bias_max=8,
         )
-        if attn_impl != 'flash' and attn_uses_sequence_id and sequence_id is not None:
+        if attn_impl == 'torch' and attn_uses_sequence_id and sequence_id is not None:
             assert isinstance(attn_bias, torch.Tensor)  # pyright
             attn_bias = apply_sequence_id(
                 attn_bias,
@@ -561,16 +561,18 @@ def test_grouped_query_invalid_heads():
         },
     }],
 )
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
 def test_reuse_prev_layer_kv_cache(
     pos_emb_config: dict,
+    attn_impl: str,
     device: str = 'cuda',
 ):
     """Checks reusing previous layer's kv cache."""
     alibi = pos_emb_config['alibi']
     rope = pos_emb_config['rope']

     cfg = {
-        'attn_impl': 'flash',
+        'attn_impl': attn_impl,
         'd_model': 64,
         'n_heads': 4,
         'attn_pdrop': 0,
@@ -630,14 +632,21 @@ def gen_bias(attn_impl: str):
             alibi=alibi,
             alibi_bias_max=8,
         )
+        if attn_impl == 'torch':
+            assert isinstance(attn_bias, torch.Tensor)  # pyright
+            attn_bias = apply_sequence_id(
+                attn_bias,
+                sequence_id,  # type: ignore
+                s,
+            )

         return attn_bias

     attention_mask_in_length = gen_attention_mask_in_length(
         sequence_id=sequence_id,
         S=s,
         attn_uses_sequence_id=True,
-        attn_impl='flash',
+        attn_impl=attn_impl,
         attention_mask=attention_mask,
     )
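
Aside: on the torch path the sequence-id information is folded into the attention bias (via apply_sequence_id above), while the flash path carries it through attention_mask_in_length. A rough sketch of the masking idea, with a hypothetical helper standing in for the library's apply_sequence_id:

    import torch

    def mask_cross_sequence(attn_bias: torch.Tensor, sequence_id: torch.Tensor) -> torch.Tensor:
        # attn_bias: (b or 1, 1, s, s); sequence_id: (b, s). Positions belonging
        # to different packed sequences receive a large negative bias so they
        # cannot attend to one another. Illustrative only, not the library helper.
        same_seq = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)  # (b, s, s)
        return attn_bias.masked_fill(~same_seq.unsqueeze(1), torch.finfo(attn_bias.dtype).min)

    bias = mask_cross_sequence(torch.zeros(1, 1, 6, 6), torch.tensor([[0, 0, 0, 1, 1, 1]]))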

@@ -656,7 +665,7 @@ def gen_bias(attn_impl: str):
     x1.requires_grad = True

     with torch.autocast(x0.device.type):
-        attn_bias_0 = gen_bias('flash')
+        attn_bias_0 = gen_bias(attn_impl)
         alibi_slopes_0 = None
         if alibi:
             alibi_slopes_0 = gen_slopes(
@@ -703,7 +712,7 @@ def gen_bias(attn_impl: str):
             flash_attn_padding_info=flash_attn_padding_info,
             alibi_slopes=alibi_slopes_0,
         )
-        attn_bias_1 = gen_bias('flash')
+        attn_bias_1 = gen_bias(attn_impl)
         alibi_slopes_1 = None
         if alibi:
             alibi_slopes_1 = gen_slopes(
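
Taken together, the test now runs once per attention implementation, and only the torch path applies the sequence-id bias to a dense attention-bias tensor. A small sketch of that dispatch pattern; the names below are hypothetical stand-ins, not the test's helpers:

    import pytest
    import torch

    @pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
    def test_bias_dispatch(attn_impl: str):
        def fake_gen_bias(impl: str, s: int = 8):
            # Illustrative assumption: only the torch impl needs a dense (1, 1, s, s)
            # bias here; the flash impl gets sequence info via attention_mask_in_length.
            return torch.zeros(1, 1, s, s) if impl == 'torch' else None

        bias = fake_gen_bias(attn_impl)
        assert (bias is None) == (attn_impl == 'flash')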
