From 34e4a99cbb3447e65f172ef22978e45d16d6b2d9 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Wed, 17 Jan 2024 21:00:39 +0000
Subject: [PATCH] ..

---
 llmfoundry/models/layers/attention.py | 46 ---------------------------
 llmfoundry/models/mpt/modeling_mpt.py | 27 ++++++++++------
 2 files changed, 18 insertions(+), 55 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 83f0161c07..055bcdd018 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -9,7 +9,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import transformers
 from einops import rearrange
 from packaging import version
@@ -86,51 +85,6 @@ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden.reshape(b, s, kv_n_heads * n_rep, d)
 
 
-def get_indices_info(attention_mask: torch.Tensor):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
-                       (1, 0))
-    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
-    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
-    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
-    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
-    # so we write custom forward and backward to make it a bit faster.
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
-
-
-def get_indices_for_concatenated_sequences(
-        attention_mask_in_length: torch.Tensor):
-    length = attention_mask_in_length.sum(dim=-1)
-    seqlen = attention_mask_in_length.size(-1)
-    attention_mask_2d = torch.arange(
-        seqlen, device=length.device, dtype=length.dtype).expand(
-            len(length), seqlen) < length.unsqueeze(1)
-    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(),
-                                     as_tuple=False).flatten()
-    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
-    indices = torch.nonzero(attention_mask_2d.flatten(),
-                            as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
-                       (1, 0))
-    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
-    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
-    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
-    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
-    # so we write custom forward and backward to make it a bit faster.
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
-
-
 def scaled_multihead_dot_product_attention(
         query: torch.Tensor,
         key: torch.Tensor,
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 6d8f8bbae2..051db72af7 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -24,17 +24,23 @@
 from composer.models import HuggingFaceModel
 from composer.utils import dist
 
-from llmfoundry.models.layers.attention import (
-    get_indices_for_concatenated_sequences, get_indices_info,
-    is_flash_v2_installed)
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+                                                is_flash_v2_installed)
 
 if is_flash_v2_installed():
     try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
         from flash_attn.layers.rotary import \
             RotaryEmbedding as DAILRotaryEmbedding
     except Exception as e:
         raise e
 
+if is_flash_v1_installed():
+    try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
+    except Exception as e:
+        raise e
+
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -232,15 +238,18 @@ def gen_flash_attn_padding_info(
         key_padding_mask = torch.ones((bsz, past_key_len + S),
                                       dtype=torch.bool).to(device=device)
         query_padding_mask = key_padding_mask[:, -S:]
-        indices_function = get_indices_info
+        unpadding_function = bert_padding.unpad_input
     else:
         key_padding_mask = attention_mask_in_length
         query_padding_mask = attention_mask_in_length
-        indices_function = get_indices_for_concatenated_sequences
-
-    indices_q, cu_seqlens_q, max_seqlen_q = indices_function(query_padding_mask)
-    indices_k, cu_seqlens_k, max_seqlen_k = indices_function(key_padding_mask)
-    indices_v, _, _ = indices_function(key_padding_mask)
+        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+
+    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+        torch.zeros(bsz, S, 1).to(device=device), query_padding_mask)
+    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+        torch.zeros(bsz, S, 1).to(device=device), key_padding_mask)
+    _, indices_v, _, _ = unpadding_function(
+        torch.zeros(bsz, S, 1).to(device=device), key_padding_mask)
 
     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
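
For reference, a minimal usage sketch (not part of the patch itself) of the flash-attn bert_padding helper that gen_flash_attn_padding_info now calls. It assumes the four-value return (unpadded hidden states, indices, cu_seqlens, max_seqlen) that the patched code unpacks; the mask values and printed results below are illustrative only, and newer flash-attn releases may return additional values.

    import torch
    from flash_attn import bert_padding

    bsz, S = 2, 5
    # Two sequences: the first has 3 real tokens, the second is full length.
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]], dtype=torch.bool)

    # Only the padding metadata is needed, so a dummy (bsz, S, 1) tensor stands
    # in for the hidden states, mirroring the torch.zeros(bsz, S, 1) calls in
    # the patched gen_flash_attn_padding_info.
    _, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(
        torch.zeros(bsz, S, 1), attention_mask)

    print(indices)     # flattened positions of non-pad tokens: [0, 1, 2, 5, 6, 7, 8, 9]
    print(cu_seqlens)  # cumulative sequence lengths: [0, 3, 8] (int32)
    print(max_seqlen)  # longest sequence in the batch: 5

bert_padding.unpad_input_for_concatenated_sequences follows the same pattern, except that its second argument is an attention_mask_in_length tensor encoding the per-sequence lengths of packed (concatenated) batches rather than a boolean padding mask.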