modified is_flash_v2_installed
ShashankMosaicML committed Nov 3, 2023
1 parent ac0fd40 commit e59f784
Showing 1 changed file with 2 additions and 1 deletion.
llmfoundry/models/layers/attention.py: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
+    assert version.parse(v2_version) >= version.parse('2.0.0')
     try:
         import flash_attn as flash_attn
     except:
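The hunk cuts off at the `except:` clause, so the remainder of the helper is not shown here. A minimal sketch of the completed check, assuming the `except` branch simply returns `False` and the final line compares the installed `flash_attn.__version__` against the requested minimum (an assumption, not taken from the diff), might look like:

```python
from packaging import version


def is_flash_v2_installed(v2_version: str = '2.0.0'):
    # Added in this commit: reject version thresholds below 2.0.0 up front.
    assert version.parse(v2_version) >= version.parse('2.0.0')
    try:
        import flash_attn as flash_attn
    except:  # any import failure means flash-attn v2 is unavailable
        return False
    # Assumed completion: report whether the installed flash_attn
    # satisfies the requested minimum v2 version.
    return version.parse(flash_attn.__version__) >= version.parse(v2_version)
```

With the new assert, a call such as `is_flash_v2_installed(v2_version='1.0.5')` fails immediately instead of silently checking against a pre-v2 release.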
@@ -597,7 +598,7 @@ def forward(
             seq_len = rotary_emb_w_meta_info['seq_len']
             offset_info = rotary_emb_w_meta_info['offset_info']
             assert query.shape[:2] == key.shape[:2]
             assert query.shape[:2] == value.shape[:2]
             bsz, seqlen = query.shape[:2]
             query = query.view(bsz, seqlen, -1, self.head_dim)
             key = key.view(bsz, seqlen, -1, self.head_dim)
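For context, the lines in this hunk reshape the flat projection outputs into explicit per-head tensors before the rotary embedding is applied. A self-contained sketch with made-up sizes (the values below are illustrative, not taken from the repository) shows the effect of the `view` calls:

```python
import torch

# Hypothetical sizes, chosen only to illustrate the reshape above.
bsz, seqlen, n_heads, head_dim = 2, 16, 4, 64

# Attention projections typically produce flat (bsz, seqlen, n_heads * head_dim) tensors.
query = torch.randn(bsz, seqlen, n_heads * head_dim)
key = torch.randn(bsz, seqlen, n_heads * head_dim)
value = torch.randn(bsz, seqlen, n_heads * head_dim)

# The asserts in the hunk guarantee matching batch and sequence dimensions.
assert query.shape[:2] == key.shape[:2]
assert query.shape[:2] == value.shape[:2]

# view(...) splits the last dimension into explicit (n_heads, head_dim) axes,
# the layout that rotary embeddings expect.
query = query.view(bsz, seqlen, -1, head_dim)
key = key.view(bsz, seqlen, -1, head_dim)
print(query.shape)  # torch.Size([2, 16, 4, 64])
```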
