
Commit 293dde2

ShashankMosaicML committed Jan 18, 2024
1 parent d844c5f commit 293dde2
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/models/test_rope_dail_vs_hf.py
@@ -7,7 +7,8 @@
 from omegaconf import OmegaConf as om
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
-from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
+from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
+                                                gen_rotary_embedding)
 
 
 @pytest.mark.gpu
@@ -104,14 +105,20 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
                          attn_bias=None,
                          attention_mask=attention_mask,
                          rotary_emb_w_meta_info=dail_rope_w_meta_info,
-                         is_causal=True)
+                         is_causal=True,
+                         flash_attn_padding_info=gen_flash_attn_padding_info(
+                             batch_size, seq_len, 0, torch.device(device), None,
+                             attention_mask))
 
         y1, _, _ = attn1(x1,
                          past_key_value=None,
                          attn_bias=None,
                          attention_mask=attention_mask,
                          rotary_emb_w_meta_info=hf_rope_w_meta_info,
-                         is_causal=True)
+                         is_causal=True,
+                         flash_attn_padding_info=gen_flash_attn_padding_info(
+                             batch_size, seq_len, 0, torch.device(device), None,
+                             attention_mask))
 
         y0 *= attention_mask.unsqueeze(-1)
         y1 *= attention_mask.unsqueeze(-1)
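Aside (not part of the commit): a minimal sketch of how the new keyword is wired up, based only on the calls visible in this diff. The positional arguments to gen_flash_attn_padding_info are copied from the diff; the tensor shapes, the all-ones attention mask, and the surrounding setup are assumptions standing in for the rest of the test.

# Sketch only, assuming llm-foundry with flash-attn v2 installed and a CUDA device.
import torch

from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

batch_size, seq_len = 2, 128  # assumed test dimensions
device = 'cuda'
# All-ones mask, i.e. no padded tokens, as a stand-in for the test's mask.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)

# Same positional call as in the diff:
# (batch size, sequence length, past length 0, device, None, attention mask).
flash_attn_padding_info = gen_flash_attn_padding_info(
    batch_size, seq_len, 0, torch.device(device), None, attention_mask)

# The result is then passed to each attention forward alongside is_causal=True, e.g.
# attn0(x0, ..., is_causal=True, flash_attn_padding_info=flash_attn_padding_info).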
