
Commit

Update comment
irenedea committed Oct 30, 2023
1 parent 0582982 commit 2831c98
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -78,7 +78,7 @@ def llama_attention_patch_torch(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    # Temporary fix for llama2 transformers compatibility, padding_mask will be removed in the next transformers release >4.32.1.
+    # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
     padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_cache:
@@ -188,7 +188,7 @@ def llama_attention_patch_triton(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    # Temporary fix for llama2 transformers compatibility, padding_mask will be removed in the next transformers release >4.32.1.
+    # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
     padding_mask: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_cache:
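For context, a minimal, hypothetical sketch of the compatibility idea behind the updated comment: the patched attention forward accepts a `padding_mask` keyword so it can be called both by transformers releases up to 4.34.1 (which still pass it) and by later releases (which drop it). The function and tensors below are illustrative stand-ins, not llm-foundry code.

```python
from typing import Optional

import torch


def patched_attention_forward(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    # Accepted only for compatibility with transformers <= 4.34.1; ignored.
    padding_mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
    # Stand-in for the real attention computation in the monkeypatch.
    return hidden_states


x = torch.randn(1, 4, 8)
# Older transformers call sites still pass padding_mask...
out_old = patched_attention_forward(x, padding_mask=torch.ones(1, 4, dtype=torch.long))
# ...while newer ones do not; both calls succeed with the same patched signature.
out_new = patched_attention_forward(x)
```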
2 changes: 1 addition & 1 deletion tests/test_huggingface_flash.py
@@ -148,7 +148,7 @@ def test_attn_patch_integration(patch: str):
     model.to('cuda')
 
     with get_precision_context('amp_bf16'):
-        # We're just testing that flash attention 2 runs okay
+        # We're just testing that the attention patch runs okay
         outputs = model(tokenized_input)
         loss = outputs.loss
         loss.backward()
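The test change above only renames what the smoke test claims to cover; the pattern itself is a forward and backward pass under the 'amp_bf16' precision context, checking only that it completes. A minimal sketch of that pattern, with a tiny stand-in module instead of the patched Llama model (the composer import path is an assumption; the diff only shows the call site):

```python
import torch
from composer.core.precision import get_precision_context  # assumed import path

# Tiny stand-in for the patched HuggingFace model used in the real test.
model = torch.nn.Linear(8, 8).to('cuda')
inputs = torch.randn(2, 8, device='cuda')

with get_precision_context('amp_bf16'):
    # We're just testing that the forward/backward pass runs okay.
    loss = model(inputs).sum()
    loss.backward()
```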
