Commit

ShashankMosaicML committed Jan 18, 2024
1 parent 18bf7ca commit 11d3d70
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -234,6 +234,8 @@ def flash_attn_fn(
     flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
+    if key_padding_mask is not None:
+        raise ValueError('key_padding_mask should be None for flash attn.')
     del key_padding_mask
     if flash_attn_padding_info is None:
         raise ValueError('flash_attn_padding_info is required for flash attn.')
@@ -668,6 +670,8 @@ def forward(

         extra_attn_kwargs = {}
         if self.attn_impl == 'flash':
+            if flash_attn_padding_info is not None:
+                key_padding_mask = None
             extra_attn_kwargs = {
                 'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
                 'sliding_window_size': self.sliding_window_size,
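The two hunks work together: forward() now clears key_padding_mask whenever flash_attn_padding_info is supplied, and flash_attn_fn rejects any call that still passes a mask. Below is a minimal sketch of that contract; the helper name check_flash_padding_args and the toy tensors are hypothetical illustrations, not code from the repository.

# Minimal sketch of the contract added in this commit. Only the two checks
# mirror the diff above; the helper name and toy tensors are hypothetical.
from typing import Optional

import torch


def check_flash_padding_args(
    key_padding_mask: Optional[torch.Tensor],
    flash_attn_padding_info: Optional[dict],
) -> None:
    # Mirrors the new guard in flash_attn_fn: padding must arrive via
    # flash_attn_padding_info, never as an explicit key_padding_mask.
    if key_padding_mask is not None:
        raise ValueError('key_padding_mask should be None for flash attn.')
    if flash_attn_padding_info is None:
        raise ValueError('flash_attn_padding_info is required for flash attn.')


# Caller side, mirroring the change in forward(): drop the mask once
# padding info is available.
key_padding_mask: Optional[torch.Tensor] = torch.ones(2, 8, dtype=torch.bool)
flash_attn_padding_info = {'indices_q': torch.arange(16)}  # hypothetical contents
if flash_attn_padding_info is not None:
    key_padding_mask = None
check_flash_padding_args(key_padding_mask, flash_attn_padding_info)  # no error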
