Commit
ShashankMosaicML committed Jan 17, 2024
1 parent 0f95056 commit 5063149
Showing 2 changed files with 55 additions and 11 deletions.
46 changes: 46 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -9,6 +9,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from einops import rearrange
from packaging import version
@@ -85,6 +86,51 @@ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden.reshape(b, s, kv_n_heads * n_rep, d)


def get_indices_info(attention_mask: torch.Tensor):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def get_indices_for_concatenated_sequences(
        attention_mask_in_length: torch.Tensor):
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype).expand(
            len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(),
                                     as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(),
                            as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def scaled_multihead_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
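As a quick illustration of what the two new helpers compute, here is a hand-worked sketch (not part of the commit; the toy masks below are made up, and the expected values follow directly from the code above):

import torch
from llmfoundry.models.layers.attention import (
    get_indices_for_concatenated_sequences, get_indices_info)

# Standard padding mask: batch of 2, seqlen 4, with 3 and 2 real tokens.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = get_indices_info(attention_mask)
# indices    -> tensor([0, 1, 2, 4, 5])  positions of real tokens in the flattened batch
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)  cumulative sequence lengths
# max_seqlen -> 3

# Packed ("concatenated") sequences: row 0 packs lengths 2 and 3 into seqlen 5,
# row 1 holds a single sequence of length 4 plus one unused position.
attention_mask_in_length = torch.tensor([[2, 3, 0, 0, 0],
                                         [4, 0, 0, 0, 0]])
indices, cu_seqlens, max_seqlen = get_indices_for_concatenated_sequences(
    attention_mask_in_length)
# indices    -> tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
# cu_seqlens -> tensor([0, 2, 5, 9], dtype=torch.int32)  one boundary per packed sequence
# max_seqlen -> 4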
20 changes: 9 additions & 11 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -24,11 +24,12 @@
from composer.models import HuggingFaceModel
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.attention import (
get_indices_for_concatenated_sequences, get_indices_info,
is_flash_v2_installed)

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
from flash_attn import bert_padding
from flash_attn.layers.rotary import \
RotaryEmbedding as DAILRotaryEmbedding
except Exception as e:
@@ -231,18 +232,15 @@ def gen_flash_attn_padding_info(
key_padding_mask = torch.ones((bsz, past_key_len + S),
dtype=torch.bool).to(device=device)
query_padding_mask = key_padding_mask[:, -S:]
unpadding_function = bert_padding.unpad_input
indices_function = get_indices_info
else:
key_padding_mask = attention_mask_in_length
query_padding_mask = attention_mask_in_length
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

_, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
torch.zeros(1, 1, 1).to(device=device), query_padding_mask)
_, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
torch.zeros(1, 1, 1).to(device=device), key_padding_mask)
_, indices_v, _, _ = unpadding_function(torch.zeros(1, 1, 1),
key_padding_mask)
indices_function = get_indices_for_concatenated_sequences

indices_q, cu_seqlens_q, max_seqlen_q = indices_function(query_padding_mask)
indices_k, cu_seqlens_k, max_seqlen_k = indices_function(key_padding_mask)
indices_v, _, _ = indices_function(key_padding_mask)

flash_attn_padding_info['indices_q'] = indices_q
flash_attn_padding_info['indices_k'] = indices_k
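For context, here is a minimal sketch (not part of the diff) of the no-padding-mask branch after this change; the batch size and sequence length are arbitrary and the expected values are worked out by hand:

import torch
from llmfoundry.models.layers.attention import get_indices_info

bsz, S, past_key_len = 2, 4, 0  # illustrative values only
key_padding_mask = torch.ones((bsz, past_key_len + S), dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -S:]

indices_q, cu_seqlens_q, max_seqlen_q = get_indices_info(query_padding_mask)
indices_k, cu_seqlens_k, max_seqlen_k = get_indices_info(key_padding_mask)
indices_v, _, _ = get_indices_info(key_padding_mask)
# With no padding, every position is kept:
# indices_q == arange(8), cu_seqlens_q == tensor([0, 4, 8]), max_seqlen_q == 4.
# Unlike the previous bert_padding.unpad_input path, no dummy torch.zeros(1, 1, 1)
# tensor has to be "unpadded" just to recover the indices.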
