fix API and revert temp change
dakinggg committed Oct 10, 2023
1 parent dfee60e commit e4d0244
Showing 2 changed files with 18 additions and 18 deletions.
Dockerfile: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ ARG DEP_GROUPS

 # Install and uninstall foundry to cache foundry requirements
 # TEMPORARY CHANGE DO NOT MERGE
-RUN git clone -b flash2-upstream https://github.com/mosaicml/llm-foundry.git
+RUN git clone -b main https://github.com/mosaicml/llm-foundry.git
 RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}"
 RUN pip uninstall -y llm-foundry
 RUN rm -rf llm-foundry
llmfoundry/models/layers/attention.py: 17 additions & 17 deletions
@@ -296,27 +296,27 @@ def flash_attn_fn(

     if is_flash_v1_installed():
         output_unpad = flash_attn_interface.flash_attn_unpadded_func(
-            query_unpad,
-            key_unpad,
-            value_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            dropout_p,
+            q=query_unpad,
+            k=key_unpad,
+            v=value_unpad,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
     elif is_flash_v2_installed():
-        output_unpad = flash_attn_interface.flash_attn_func(
-            query_unpad,
-            key_unpad,
-            value_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            dropout_p,
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(
+            q=query_unpad,
+            k=key_unpad,
+            v=value_unpad,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
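For reference, a minimal sketch of the API difference this hunk fixes: flash-attn v2 exposes the variable-length (unpadded) kernel as flash_attn_varlen_func, while v2's flash_attn_func expects padded (batch, seqlen, n_heads, head_dim) inputs and takes no cu_seqlens, so the old positional call no longer lined up with the signature. The wrapper below is an illustrative assumption, not llm-foundry code (the repo dispatches on its own is_flash_v1_installed/is_flash_v2_installed helpers, as in the context lines above); the keyword names match those in the diff, and the flash-attn entry points are the ones shipped in v1 and v2.

# Hedged sketch: dispatch the variable-length attention call across flash-attn
# major versions, passing everything by keyword as the commit does.
import torch
from typing import Optional
from flash_attn import flash_attn_interface

def varlen_attention(query_unpad: torch.Tensor,
                     key_unpad: torch.Tensor,
                     value_unpad: torch.Tensor,
                     cu_seqlens_q: torch.Tensor,
                     cu_seqlens_k: torch.Tensor,
                     max_seqlen_q: int,
                     max_seqlen_k: int,
                     dropout_p: float = 0.0,
                     softmax_scale: Optional[float] = None,
                     causal: bool = False) -> torch.Tensor:
    """Run the unpadded attention kernel under either flash-attn 1.x or 2.x."""
    if hasattr(flash_attn_interface, 'flash_attn_varlen_func'):
        # flash-attn v2: the unpadded kernel was renamed to flash_attn_varlen_func;
        # v2's flash_attn_func takes padded inputs instead, which is why the old
        # positional call against it broke.
        fn = flash_attn_interface.flash_attn_varlen_func
    else:
        # flash-attn v1 name for the same variable-length kernel.
        fn = flash_attn_interface.flash_attn_unpadded_func
    # Keyword arguments keep the call correct even if positional order ever differs.
    return fn(
        q=query_unpad,
        k=key_unpad,
        v=value_unpad,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )

Here query_unpad, key_unpad, and value_unpad are (total_tokens, n_heads, head_dim) tensors and cu_seqlens_q/cu_seqlens_k are int32 cumulative sequence-length offsets, as produced by flash-attn's unpadding helpers.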
