diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 5bcc3e35fca4..bd15ce2bdef8 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -79,10 +79,10 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Physical cache allocation + alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size) if verbose: - alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size) self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") - self._kv_caches = self._init_device_caches() + self._kv_caches = self._init_device_caches(alloc_shape) self.total_physical_cache_size_in_bytes = ( self.elem_size_in_bytes * self.num_layers @@ -297,15 +297,12 @@ def _init_logical_caches(self): blocks.append(cache_block) return blocks - def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]: + def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]: """Initialize the physical cache on the device. For each layer of the model, we allocate two tensors for key and value respectively, - with shape of [num_blocks, num_kv_heads, head_size, block_size] + with shape of [num_blocks, num_kv_heads, block_size, head_size] """ - alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size) - # TODO: Explore the performance when using difference shapes with kernel-related optimizations - # e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x] k_cache: List[torch.Tensor] = [] v_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py index ead4be8b7cd8..e4dd02b6042e 100644 --- a/colossalai/inference/modeling/layers/attention.py +++ b/colossalai/inference/modeling/layers/attention.py @@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): lengths: key/value lengths block_tables """ - num_blocks, num_heads, head_size, block_size = cache.shape + num_blocks, num_heads, block_size, head_size = cache.shape bsz, max_blocks_per_seq = block_tables.shape needed_blocks = (lengths + block_size - 1) // block_size @@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): block_num = needed_blocks[i] token_id = 0 for block_idx in range(block_num - 1): - cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0) + cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2) token_id += block_size - cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute( - 1, 2, 0 + cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute( + 1, 0, 2 ) elif type == "decoding": assert source.size(1) == 1, "seq_len should be equal to 1 when decoding." 
source = source.squeeze(1) slot_idx = (lengths + block_size - 1) % block_size for i in range(bsz): - cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i] + cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i] return cache @@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0): """ Func: convert key/value cache for calculation - Args: cache: shape [num_blocks, num_heads, head_size, block_size] + Args: cache: shape [num_blocks, num_heads, block_size, head_size] lengths: key/value length block_tables pad_id: padded_id """ - num_blocks, num_heads, head_size, block_size = cache.shape + num_blocks, num_heads, block_size, head_size = cache.shape needed_blocks = (lengths + block_size - 1) // block_size num_remaing_tokens = lengths % block_size @@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0): for i in range(bsz): _cache = torch.cat( ( - cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size), - cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1), + cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size), + cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2), ), dim=0, ) @@ -127,7 +127,7 @@ def nopad_context_forward( q: torch.Tensor, # [num_tokens, num_heads, head_size] k: torch.Tensor, # [num_tokens, num_kv_heads, head_size] v: torch.Tensor, - k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] v_cache: torch.Tensor, context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] @@ -142,7 +142,7 @@ def nopad_context_forward( assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" num_kv_groups = num_heads // num_kv_heads - block_size = k_cache.shape[-1] + block_size = k_cache.size(-2) bsz, max_blocks_per_sequence = block_tables.shape max_seq_len = max_blocks_per_sequence * block_size assert q.shape[-1] == k.shape[-1] == v.shape[-1] @@ -196,7 +196,7 @@ def pad_context_forward( q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size] k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size] v: torch.Tensor, - k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] v_cache: torch.Tensor, context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] @@ -207,7 +207,7 @@ def pad_context_forward( num_kv_heads = k.shape[-2] assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" num_kv_groups = num_heads // num_kv_heads - block_size = k_cache.shape[-1] + block_size = k_cache.size(-2) assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] block_tables.shape[-1] * block_size @@ -254,7 +254,7 @@ def pad_decoding_forward( q: torch.Tensor, # [bsz, 1, num_heads, head_size] k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size] v: torch.Tensor, - k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] v_cache: torch.Tensor, lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] diff --git 
a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 3a81a97f7a2e..569c5f05a05c 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -171,7 +171,7 @@ def llama_attn_forward( rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - _, _, _, block_size = k_cache.shape + block_size = k_cache.size(-2) if is_prompts: attn_output = context_attention_unpadded( diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py index fb66360f5a6d..63a8d367393a 100644 --- a/colossalai/inference/modeling/models/padding_llama.py +++ b/colossalai/inference/modeling/models/padding_llama.py @@ -226,7 +226,7 @@ def llama_attn_forward( rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - _, _, _, block_size = k_cache.shape + block_size = k_cache.size(-2) if is_prompts: attn_output = context_attention_unpadded( diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index 3ef43cb83dd4..68baffd53d2b 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel( stride_od, stride_cacheb, stride_cacheh, - stride_cached, stride_cachebs, + stride_cached, stride_bts, stride_btb, context_lengths, @@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel( # Copy k to corresponding cache block offsets_dmodel = tl.arange(0, HEAD_DIM) offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt - k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0) + offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt + k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0) offsets_kcachebs = tl.arange(0, BLOCK_SIZE) offsets_kcache = ( KCache + offset_kvcache - + offsets_dmodel[:, None] * stride_cached - + offsets_kcachebs[None, :] * stride_cachebs + + offsets_dmodel[None, :] * stride_cached + + offsets_kcachebs[:, None] * stride_cachebs ) - tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) # Copy v to corresponding cache block offsets_vd = offsets_dmodel offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N) - offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd - v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0) + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0) offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here offsets_vcache = ( VCache + offset_kvcache - + offsets_vcachebs[:, None] * stride_cachebs - + offsets_dmodel[None, :] * stride_cached + + offsets_vcachebs[None, :] * stride_cachebs + + offsets_dmodel[:, None] * stride_cached ) - tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) return diff --git 
a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 6b3ed2999c84..4bba2450321b 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -10,8 +10,8 @@ @triton.jit def _flash_decoding_fwd_kernel( Q, # [batch_size, head_num, q_len(1), head_dim] - KCache, # [num_blocks, num_kv_heads, head_dim, block_size] - VCache, # [num_blocks, num_kv_heads, head_dim, block_size] + KCache, # [num_blocks, num_kv_heads, block_size, head_dim] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] block_tables, # [batch_size, max_blocks_per_sequence] mid_o, # [batch_size, head_num, kv_split_num, head_dim] mid_o_lse, # [batch_size, head_num, kv_split_num] @@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel( stride_qd, stride_cacheb, stride_cacheh, - stride_cached, stride_cachebs, + stride_cached, stride_bts, stride_btb, stride_mid_ot, @@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel( K_block_ptr = tl.make_block_ptr( base=KCache + offset_kvcache, - shape=(HEAD_DIM, cur_occupied_size), - strides=(stride_cached, stride_cachebs), + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_cachebs, stride_cached), offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_SIZE), + block_shape=(BLOCK_SIZE, HEAD_DIM), order=(0, 1), ) V_block_ptr = tl.make_block_ptr( base=VCache + offset_kvcache, - shape=(HEAD_DIM, cur_occupied_size), - strides=(stride_cached, stride_cachebs), + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_cachebs, stride_cached), offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_SIZE), + block_shape=(BLOCK_SIZE, HEAD_DIM), order=(0, 1), ) k_cur_block = tl.load(K_block_ptr) @@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel( # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16, # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail. # Refer to https://github.com/openai/triton/discussions/895 - S_ij += tl.sum(q[:, None] * k_cur_block, 0) + S_ij += tl.sum(q[None, :] * k_cur_block, 1) S_ij *= sm_scale S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf")) @@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel( p_ij_hat = tl.exp(S_ij) l = tl.sum(p_ij_hat, 0) p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) - acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1) + acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) acc = acc / l offsets_mid_o = ( @@ -206,8 +206,8 @@ def flash_decoding_attention( Args: q (torch.Tensor): [bsz, num_heads, head_dim] - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] + k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] + v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] kv_seq_len (torch.Tensor): [batch_size] records the (kv) sequence lengths incorporating past kv sequence lengths. 
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] @@ -230,13 +230,13 @@ def flash_decoding_attention( assert head_dim in {32, 64, 128, 256} assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( f"Got incompatible batch size (number of seqs):\n" - f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, " + f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, " f"batch size {bsz}" ) - assert k_cache.size(-1) == v_cache.size(-1) == block_size, ( + assert k_cache.size(-2) == v_cache.size(-2) == block_size, ( f"Got incompatible block size on kv caches:\n" - f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, " - f"v_cache block_size {v_cache.size(-1)}" + f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, " + f"v_cache block_size {v_cache.size(-2)}" ) # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py index 74f20c33b10f..1aaeb6830e7c 100644 --- a/colossalai/kernel/triton/kvcache_copy.py +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel( stride_kd, stride_cacheb, stride_cacheh, - stride_cached, stride_cachebs, + stride_cached, stride_bts, stride_btb, block_size, @@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel( last_bt_block_idx = past_kv_seq_len // block_size block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) - offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs + offsets_in_last_block = past_kv_seq_len % block_size offsets_dmodel = tl.arange(0, HEAD_DIM) offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd kv = tl.load(KV + offsets_kv) offsets_kvcache = ( block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + offsets_in_last_block * stride_cachebs + offsets_dmodel * stride_cached - + offsets_in_last_block ) tl.store(KVCache + offsets_kvcache, kv) return @@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache( """ Copy keys or values to the blocked key/value cache during decoding stage. - Parameters: - - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. - - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache. - - kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. - - block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. + Args: + k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. + k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache. + kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. """ - assert k.size(-1) == k_cache.size(-2), "Incompatible head dim" + assert k.size(-1) == k_cache.size(-1), "Incompatible head dim" assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." 
- if k.dim() == 4: - assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)" - bsz, _, num_kv_heads, head_dim = k.shape - # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim] - k = k.squeeze(dim=1) - elif k.dim() == 3: - bsz, num_kv_heads, head_dim = k.shape - else: - raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.") + + k = k.squeeze(1) if k.dim() == 4 else k + assert k.dim() == 3, f"Incompatible k dim {k.dim()}" + bsz, num_kv_heads, head_dim = k.shape assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( f"Got incompatible batch size (number of seqs):\n" @@ -77,7 +72,7 @@ def copy_kv_to_blocked_cache( ) # Modify if the shape of kv cahce is changed. - block_size = k_cache.size(-1) + block_size = k_cache.size(-2) num_warps = 8 if head_dim > 128 else 4 diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index 9f7daa9a5b25..a2051f220790 100755 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -93,7 +93,7 @@ def check_cache_manager(test_config): assert len(cache_manager._cache_blocks) == num_blocks key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers assert len(key_caches) == num_layers - expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size) + expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size) assert key_caches[0].shape == expected_kv_shape k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0) expected_kv_block_shape = expected_kv_shape[1:] diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py index b4754fdea1d3..1091370ceba9 100644 --- a/tests/test_infer/test_models/test_attention.py +++ b/tests/test_infer/test_models/test_attention.py @@ -1,20 +1,17 @@ -import pytest import torch from transformers.cache_utils import DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb -import colossalai from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache -from colossalai.testing import rerun_if_address_is_in_use, spawn def test_copy_to_cache(): key = torch.ones((2, 11, 3, 3)) key[0, 9, :, :] = 0 key[1, -2:, :, :] = 0 - cache = torch.zeros(8, 3, 3, 8) + cache = torch.zeros(8, 3, 8, 3) block_tables = torch.tensor([[0, 1], [2, 3]]) lengths = torch.tensor([9, 8]) cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill") @@ -28,7 +25,7 @@ def test_copy_to_cache(): def test_convert_kvcache(): - cache = torch.ones(8, 3, 3, 8) + cache = torch.ones(8, 3, 8, 3) key = torch.ones(2, 1, 3, 3) + 1 lengths = torch.tensor([10, 9]) block_tables = torch.tensor([[0, 1], [2, 3]]) @@ -43,8 +40,8 @@ def test_context_attention(): """ attn = PagedAttention() q = k = v = torch.randn(8, 4, 4) - k_cache = torch.empty(8, 4, 4, 8) - v_cache = torch.empty(8, 4, 4, 8) + k_cache = torch.empty(8, 4, 8, 4) + v_cache = torch.empty(8, 4, 8, 4) context_lengths = torch.tensor( [ 8, @@ -136,23 +133,8 @@ def test_decoding_attention(): assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2) -def check_attention_layer(): +if __name__ == "__main__": test_copy_to_cache() test_convert_kvcache() test_context_attention() 
test_decoding_attention() - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") - check_attention_layer() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_attention_layer(): - spawn(run_dist, 1) - - -if __name__ == "__main__": - test_attention_layer() diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py index 31bd4812a8b5..7c3bc5ca6871 100644 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -106,6 +106,40 @@ def mock_alloc_block_table_and_kvcache( return block_tables +def mock_alloc_block_table_and_kvcache_v2( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2) + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2) + k_cache[block_id, :, :allocated_locs, :] = k_block + v_cache[block_id, :, :allocated_locs, :] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None: # Allocate 1 token on the block table for each seqs in block tables. # It won't change provided context_lengths. 
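All of the changes up to this point (the cache manager allocation, the naive PagedAttention copy/convert helpers, the Triton kernels, and the v2 mock-allocation helper above) follow from the same layout switch: the blocked KV cache now stores tokens as [num_blocks, num_kv_heads, block_size, head_dim] instead of [num_blocks, num_kv_heads, head_dim, block_size]. The sketch below is not part of the patch; shapes, values, and names are invented purely to illustrate how a sequence's keys land in the new layout.

# --- illustrative sketch (not part of the patch): token placement under the new
# --- blocked-cache layout [num_blocks, num_kv_heads, block_size, head_dim]
import torch

num_blocks, num_kv_heads, block_size, head_dim = 4, 2, 8, 16
k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim)

seq_len = 11
k = torch.randn(seq_len, num_kv_heads, head_dim)   # unpadded keys of one sequence
block_table = torch.tensor([2, 0])                 # logical block index -> physical block id

for token_idx in range(seq_len):
    block_id = block_table[token_idx // block_size].item()
    slot = token_idx % block_size                  # position of the token inside its block
    # new layout: each token's head_dim vector stays contiguous in the last dimension
    k_cache[block_id, :, slot, :] = k[token_idx]

# a whole physical block of the sequence has shape [num_kv_heads, block_size, head_dim]
assert k_cache[block_table[0]].shape == (num_kv_heads, block_size, head_dim)
# --- end sketch ---

Under the previous layout the same write would have been k_cache[block_id, :, :, slot] = k[token_idx], with head_dim on the second-to-last axis; the new ordering keeps each token's head_dim vector contiguous and matches the swapped stride_cachebs/stride_cached argument order in the Triton kernel signatures changed above.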
@@ -146,6 +180,22 @@ def generate_caches_and_block_tables( return k_cache, v_cache, block_tables +def generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache_v2( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + def convert_kv_unpad_to_padded( k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int ) -> torch.Tensor: diff --git a/tests/test_infer_ops/triton/test_context_attn_unpad.py b/tests/test_infer_ops/triton/test_context_attn_unpad.py index 4498b8519c3d..0a3ede5555de 100644 --- a/tests/test_infer_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer_ops/triton/test_context_attn_unpad.py @@ -6,7 +6,7 @@ from colossalai.inference.modeling.layers.attention import PagedAttention from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device -from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref +from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref try: import triton # noqa @@ -93,7 +93,7 @@ def test_context_attention( q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) q_unpad = q_unpad.contiguous() - k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables( + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) @@ -148,7 +148,6 @@ def bench_kernel( num_kv_heads = num_attn_heads // kv_group_num assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." 
- block_size * max_num_blocks_per_seq dtype = torch.float16 device = get_current_device() @@ -162,7 +161,7 @@ def bench_kernel( qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) q_unpad = q_unpad.contiguous() - k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables( + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py index 8d1a5a36c21e..a49ee3146132 100644 --- a/tests/test_infer_ops/triton/test_decoding_attn.py +++ b/tests/test_infer_ops/triton/test_decoding_attn.py @@ -6,7 +6,7 @@ from colossalai.utils import get_current_device from tests.test_infer_ops.triton.kernel_utils import ( convert_kv_unpad_to_padded, - generate_caches_and_block_tables, + generate_caches_and_block_tables_v2, prepare_padding_mask, torch_attn_ref, ) @@ -38,6 +38,9 @@ def prepare_data( ): # Use the provided maximum sequence length for each sequence when testing with teh same context length, # otherwise generate random context lengths. + # returns + # q [bsz, num_attn_heads, q_len, head_dim] + # k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim] kv_lengths = ( torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) if same_context_len @@ -83,7 +86,7 @@ def test_flash_decoding( q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device ) - k_cache, v_cache, block_tables = generate_caches_and_block_tables( + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) @@ -180,7 +183,7 @@ def bench_kernel( ) ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) if provider == "triton": - k_cache, v_cache, block_tables = generate_caches_and_block_tables( + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) diff --git a/tests/test_infer_ops/triton/test_kvcache_copy.py b/tests/test_infer_ops/triton/test_kvcache_copy.py index c2ccb5ef5f7b..3b0a0f76598e 100644 --- a/tests/test_infer_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer_ops/triton/test_kvcache_copy.py @@ -5,7 +5,7 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache from colossalai.kernel.triton import copy_kv_to_blocked_cache from colossalai.utils import get_current_device -from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, mock_alloc_single_token +from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token try: import triton # noqa @@ -17,6 +17,8 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +HEAD_DIM = 128 + def prepare_data( bsz, @@ -29,31 +31,27 @@ def prepare_data( device, dtype=torch.float16, ): - if same_context_len: - # past_kv_seq_lengths in this test records the previous kv seq len - # (not incorporating 
the current input whose seq len is 1) - past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device) - else: - past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device) + # past_kv_seq_lengths in this test records the previous kv seq len + # (not incorporating the current input whose seq len is 1) + past_kv_seq_lengths = ( + torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device) + if same_context_len + else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device) + ) num_tokens = torch.sum(past_kv_seq_lengths).item() kv_size = (num_tokens, 2 * num_kv_heads, head_dim) - kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) - k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2) - - cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) - k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) - v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) - # Mock allocation on block tables as well as blocked kv caches - block_tables = mock_alloc_block_table_and_kvcache( - k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2) + + k_cache, _, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device ) block_tables = block_tables.to(device=device) new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device) # mock allocating blocks for the new k/v and update block tables mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) - # kv seq len = past kv seq len + seq len (1 during decoding stage) kv_seq_lengths = past_kv_seq_lengths + 1 @@ -78,7 +76,6 @@ def test_copy_kv_to_caches( torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() - head_dim = 128 max_seq_len = block_size * max_num_blocks_per_seq dtype = torch.float16 device = get_current_device() @@ -86,7 +83,7 @@ def test_copy_kv_to_caches( new_k, k_cache, kv_seq_lengths, block_tables = prepare_data( bsz, num_kv_heads, - head_dim, + HEAD_DIM, block_size, max_num_blocks_per_seq, same_context_len, @@ -94,20 +91,28 @@ def test_copy_kv_to_caches( device=device, dtype=dtype, ) + # k_cache_torch = k_cache.clone().detach() + # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding") copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables) - for seq_i in range(bsz): - ki = new_k[seq_i] - ki = ki.squeeze() - past_kv_seq_len = kv_seq_lengths[seq_i] - 1 - target_block_id = block_tables[seq_i, past_kv_seq_len // block_size] - offsets_in_block = past_kv_seq_len % block_size - target = k_cache[target_block_id, :, :, offsets_in_block] - orig = new_k[seq_i].squeeze(dim=0) - assert torch.equal(orig, target) + past_kv_seq_len = kv_seq_lengths - 1 + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size + target = k_cache[target_block_ids, :, offsets_in_block, :] + source = new_k.squeeze() + + assert target.shape == source.shape + assert torch.equal(target, source) + # target_torch = k_cache_copy[target_block_ids, :, 
offsets_in_block, :] + # assert target_torch.shape == source.shape + # assert torch.equal(target_torch, source) BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 configs = [ triton.testing.Benchmark( x_names=["KV_SEQ_LEN"], @@ -133,10 +138,6 @@ def benchmark_kvcache_copy( num_kv_heads: int, same_context_len: bool, ): - warmup = 10 - rep = 100 - - head_dim = 128 dtype = torch.float16 device = get_current_device() @@ -145,7 +146,7 @@ def benchmark_kvcache_copy( new_k, k_cache, context_lengths, block_tables = prepare_data( bsz, num_kv_heads, - head_dim, + HEAD_DIM, block_size, max_seq_len // block_size, same_context_len, @@ -154,15 +155,14 @@ def benchmark_kvcache_copy( dtype=dtype, ) + quantiles = [0.5, 0.2, 0.8] if provider == "torch_copy_func": fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding") - elif provider == "triton_copy_func": + if provider == "triton_copy_func": fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables) - else: - raise ValueError("Undefined provider.") - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - return ms + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + return ms, min_ms, max_ms if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_rmsnorm_triton.py b/tests/test_infer_ops/triton/test_rmsnorm_triton.py index 7cc69657cd85..cc0ef292ffab 100644 --- a/tests/test_infer_ops/triton/test_rmsnorm_triton.py +++ b/tests/test_infer_ops/triton/test_rmsnorm_triton.py @@ -3,7 +3,6 @@ import triton from packaging import version from transformers.models.llama.modeling_llama import LlamaRMSNorm -from vllm.model_executor.layers.layernorm import RMSNorm from colossalai.kernel.triton import rms_layernorm from colossalai.testing.utils import parameterize @@ -36,7 +35,8 @@ def test_layer_norm(M, N): y_triton = rms_layernorm(x, weight, eps=eps) y_llama = rms_norm.forward(x).to(dtype) - assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5) + assert y_triton.shape == y_llama.shape + assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3) # Triton benchmark plot attributions @@ -45,8 +45,8 @@ def test_layer_norm(M, N): x_names=["SEQUENCE_TOTAL"], x_vals=[i for i in range(128, 1025, 128)], line_arg="provider", - line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"], - line_names=["vllm_rms_layernorm", "triton_rms_layernorm"], + line_vals=["torch_rms_layernorm", "triton_rms_layernorm"], + line_names=["torch_rms_layernorm", "triton_rms_layernorm"], styles=[("red", "-"), ("blue", "-")], ylabel="ms", plot_name=f"RMSNorm benchmarking results", @@ -69,10 +69,10 @@ def benchmark_rms_layernorm( x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE) w_shape = (x_shape[-1],) weight = torch.ones(w_shape, dtype=dtype, device="cuda") - vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda") + torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda") x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - if provider == "vllm_rms_layernorm": - fn = lambda: vllm_norm(x) + if provider == "torch_rms_layernorm": + fn = lambda: torch_norm(x) elif provider == "triton_rms_layernorm": fn = lambda: rms_layernorm(x, weight, eps=eps) else:
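Stepping back to the kvcache-copy test changes earlier in this diff: the rewritten assertion in test_copy_kv_to_caches drops the per-sequence Python loop in favor of a single batched gather over the cache. The sketch below is not part of the patch; sizes and values are invented purely to illustrate that verification pattern on the new layout.

# --- illustrative sketch (not part of the patch): the batched gather used to verify
# --- the decoding-step copy, with made-up sizes
import torch

bsz, num_kv_heads, head_dim, block_size = 4, 2, 16, 8
num_blocks, max_blocks_per_seq = 16, 4
k_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_dim)
block_tables = torch.randint(0, num_blocks, (bsz, max_blocks_per_seq))
kv_seq_lengths = torch.tensor([5, 8, 13, 25])      # lengths after the new token was appended

past_len = kv_seq_lengths - 1                      # slot of the token written in this step
target_block_ids = block_tables[torch.arange(bsz), past_len // block_size]
offsets_in_block = past_len % block_size

# advanced indexing pairs the i-th block id with the i-th in-block offset,
# yielding one [num_kv_heads, head_dim] slice per sequence
target = k_cache[target_block_ids, :, offsets_in_block, :]
assert target.shape == (bsz, num_kv_heads, head_dim)
# --- end sketch ---

The same indexing, k_cache[target_block_ids, :, offsets_in_block, :], is what the updated test compares against new_k.squeeze(); under the old layout the equivalent gather would have been k_cache[target_block_ids, :, :, offsets_in_block] instead.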