[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
yuanheng-zhao authored Jan 30, 2024
1 parent e8f0642 commit 5f98a9d
Showing 14 changed files with 172 additions and 146 deletions.
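At a glance, the blocked KV cache layout moves from [num_blocks, num_kv_heads, head_size, block_size] to [num_blocks, num_kv_heads, block_size, head_size], so each cached token's per-head vector sits contiguously along the last dimension. A minimal sketch of the two layouts (sizes below are hypothetical, for illustration only):

```python
import torch

# Hypothetical sizes, for illustration only.
num_blocks, num_kv_heads, block_size, head_size = 256, 8, 16, 64

old_layout = torch.empty(num_blocks, num_kv_heads, head_size, block_size)  # before this commit
new_layout = torch.empty(num_blocks, num_kv_heads, block_size, head_size)  # after this commit

# In the new layout, one cached token of one kv head is a contiguous head_size-long vector,
# which is the unit the copy and attention kernels read and write per token.
assert new_layout[0, 0, 0, :].is_contiguous()
assert not old_layout[0, 0, :, 0].is_contiguous()
```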
11 changes: 4 additions & 7 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -79,10 +79,10 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

# Physical cache allocation
+ alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
if verbose:
- alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
- self._kv_caches = self._init_device_caches()
+ self._kv_caches = self._init_device_caches(alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
* self.num_layers
@@ -297,15 +297,12 @@ def _init_logical_caches(self):
blocks.append(cache_block)
return blocks

- def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize the physical cache on the device.
For each layer of the model, we allocate two tensors for key and value respectively,
- with shape of [num_blocks, num_kv_heads, head_size, block_size]
+ with shape of [num_blocks, num_kv_heads, block_size, head_size]
"""
- alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
- # TODO: Explore the performance when using difference shapes with kernel-related optimizations
- # e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
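For reference, a rough sketch of what `_init_device_caches` does with the shape it now receives (attribute names mirror the manager's; dtype and device handling are simplified assumptions):

```python
from typing import List, Tuple

import torch


def init_device_caches_sketch(
    num_layers: int,
    alloc_shape: Tuple[int, int, int, int],  # (num_blocks, num_kv_heads, block_size, head_size)
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # One key tensor and one value tensor per layer, all in the blocked layout.
    k_caches = [torch.zeros(alloc_shape, dtype=dtype, device=device) for _ in range(num_layers)]
    v_caches = [torch.zeros(alloc_shape, dtype=dtype, device=device) for _ in range(num_layers)]
    return k_caches, v_caches
```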
28 changes: 14 additions & 14 deletions colossalai/inference/modeling/layers/attention.py
@@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
lengths: key/value lengths
block_tables
"""
- num_blocks, num_heads, head_size, block_size = cache.shape
+ num_blocks, num_heads, block_size, head_size = cache.shape
bsz, max_blocks_per_seq = block_tables.shape
needed_blocks = (lengths + block_size - 1) // block_size

@@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
block_num = needed_blocks[i]
token_id = 0
for block_idx in range(block_num - 1):
- cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
+ cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2)
token_id += block_size
- cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
- 1, 2, 0
+ cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute(
+ 1, 0, 2
)
elif type == "decoding":
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
source = source.squeeze(1)
slot_idx = (lengths + block_size - 1) % block_size
for i in range(bsz):
- cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
+ cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i]

return cache

@@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
- Args: cache: shape [num_blocks, num_heads, head_size, block_size]
+ Args: cache: shape [num_blocks, num_heads, block_size, head_size]
lengths: key/value length
block_tables
pad_id: padded_id
"""
- num_blocks, num_heads, head_size, block_size = cache.shape
+ num_blocks, num_heads, block_size, head_size = cache.shape

needed_blocks = (lengths + block_size - 1) // block_size
num_remaing_tokens = lengths % block_size
@@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
for i in range(bsz):
_cache = torch.cat(
(
- cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
- cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
+ cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size),
+ cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2),
),
dim=0,
)
@@ -127,7 +127,7 @@ def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
v: torch.Tensor,
- k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
+ k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
@@ -142,7 +142,7 @@ def nopad_context_forward(
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads

- block_size = k_cache.shape[-1]
+ block_size = k_cache.size(-2)
bsz, max_blocks_per_sequence = block_tables.shape
max_seq_len = max_blocks_per_sequence * block_size
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
@@ -196,7 +196,7 @@ def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
v: torch.Tensor,
- k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
+ k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
@@ -207,7 +207,7 @@ def pad_context_forward(
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads
- block_size = k_cache.shape[-1]
+ block_size = k_cache.size(-2)
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size

@@ -254,7 +254,7 @@ def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
v: torch.Tensor,
- k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
+ k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
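With `block_size` ahead of `head_size`, a chunk of `block_size` tokens shaped `[block_size, num_heads, head_size]` lands in a cache block with a single `permute(1, 0, 2)`, and a decoded token fills one slot along the block dimension. A small self-contained check (hypothetical sizes):

```python
import torch

block_size, num_kv_heads, head_size = 16, 4, 64
cache_block = torch.zeros(num_kv_heads, block_size, head_size)  # one block of the new layout
chunk = torch.randn(block_size, num_kv_heads, head_size)        # block_size tokens of keys or values

# Prefill: [block_size, heads, head_size] -> [heads, block_size, head_size]
cache_block.copy_(chunk.permute(1, 0, 2))

# Decoding: a single new token occupies one slot along the block_size dimension.
slot_idx = 5
new_token = torch.randn(num_kv_heads, head_size)
cache_block[:, slot_idx, :] = new_token

# Reading tokens back out for computation ("convert") is the inverse permute.
restored = cache_block.permute(1, 0, 2)  # [block_size, heads, head_size]
assert torch.equal(restored[:slot_idx], chunk[:slot_idx])
assert torch.equal(restored[slot_idx], new_token)
```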
2 changes: 1 addition & 1 deletion colossalai/inference/modeling/models/nopadding_llama.py
@@ -171,7 +171,7 @@ def llama_attn_forward(

rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

- _, _, _, block_size = k_cache.shape
+ block_size = k_cache.size(-2)

if is_prompts:
attn_output = context_attention_unpadded(
2 changes: 1 addition & 1 deletion colossalai/inference/modeling/models/padding_llama.py
@@ -226,7 +226,7 @@ def llama_attn_forward(

rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

- _, _, _, block_size = k_cache.shape
+ block_size = k_cache.size(-2)

if is_prompts:
attn_output = context_attention_unpadded(
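Both llama modeling files now take the block size from dim -2 of the cache instead of unpacking the full shape; the head dim stays on dim -1. A quick check with hypothetical sizes:

```python
import torch

k_cache = torch.empty(128, 8, 16, 64)  # [num_blocks, num_kv_heads, block_size, head_dim]
block_size = k_cache.size(-2)          # 16 under the revised layout
head_dim = k_cache.size(-1)            # 64
assert (block_size, head_dim) == (16, 64)
```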
22 changes: 11 additions & 11 deletions colossalai/kernel/triton/context_attn_unpad.py
@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
stride_od,
stride_cacheb,
stride_cacheh,
- stride_cached,
stride_cachebs,
+ stride_cached,
stride_bts,
stride_btb,
context_lengths,
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
# Copy k to corresponding cache block
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
- k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+ offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
+ k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
offsets_kcache = (
KCache
+ offset_kvcache
- + offsets_dmodel[:, None] * stride_cached
- + offsets_kcachebs[None, :] * stride_cachebs
+ + offsets_dmodel[None, :] * stride_cached
+ + offsets_kcachebs[:, None] * stride_cachebs
)
- tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+ tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
# Copy v to corresponding cache block
offsets_vd = offsets_dmodel
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
- offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
- v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+ offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
+ v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here
offsets_vcache = (
VCache
+ offset_kvcache
- + offsets_vcachebs[:, None] * stride_cachebs
- + offsets_dmodel[None, :] * stride_cached
+ + offsets_vcachebs[None, :] * stride_cachebs
+ + offsets_dmodel[:, None] * stride_cached
)
- tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+ tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)

return

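Outside Triton, the K-copy this kernel performs during prefill can be pictured with the following PyTorch reference for one sequence and one kv head (`block_table` maps logical block index to physical block id; this is a sketch, not the kernel itself):

```python
import torch


def copy_prefill_keys_sketch(
    k: torch.Tensor,            # [seq_len, head_dim], keys of one kv head for one sequence
    k_cache: torch.Tensor,      # [num_blocks, num_kv_heads, block_size, head_dim]
    block_table: torch.Tensor,  # [max_blocks_per_seq], physical block ids for this sequence
    kv_head_idx: int,
) -> None:
    block_size = k_cache.size(-2)
    seq_len = k.size(0)
    for start in range(0, seq_len, block_size):
        n = min(block_size, seq_len - start)
        block_id = block_table[start // block_size]
        # Tokens fill dim -2 of the block; each token's head_dim vector stays contiguous on dim -1.
        k_cache[block_id, kv_head_idx, :n, :] = k[start : start + n]
```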
34 changes: 17 additions & 17 deletions colossalai/kernel/triton/flash_decoding.py
@@ -10,8 +10,8 @@
@triton.jit
def _flash_decoding_fwd_kernel(
Q, # [batch_size, head_num, q_len(1), head_dim]
- KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
- VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
+ KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
+ VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
@@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel(
stride_qd,
stride_cacheb,
stride_cacheh,
- stride_cached,
stride_cachebs,
+ stride_cached,
stride_bts,
stride_btb,
stride_mid_ot,
@@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel(

K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
- shape=(HEAD_DIM, cur_occupied_size),
- strides=(stride_cached, stride_cachebs),
+ shape=(cur_occupied_size, HEAD_DIM),
+ strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
- block_shape=(HEAD_DIM, BLOCK_SIZE),
+ block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=VCache + offset_kvcache,
- shape=(HEAD_DIM, cur_occupied_size),
- strides=(stride_cached, stride_cachebs),
+ shape=(cur_occupied_size, HEAD_DIM),
+ strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
- block_shape=(HEAD_DIM, BLOCK_SIZE),
+ block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
k_cur_block = tl.load(K_block_ptr)
@@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel(
# NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
# Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
# Refer to https://github.com/openai/triton/discussions/895
- S_ij += tl.sum(q[:, None] * k_cur_block, 0)
+ S_ij += tl.sum(q[None, :] * k_cur_block, 1)
S_ij *= sm_scale
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))

@@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel(
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
- acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1)
+ acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l

offsets_mid_o = (
@@ -206,8 +206,8 @@ def flash_decoding_attention(
Args:
q (torch.Tensor): [bsz, num_heads, head_dim]
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
- v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
+ k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
+ v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
kv_seq_len (torch.Tensor): [batch_size]
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
@@ -230,13 +230,13 @@ def flash_decoding_attention(
assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, "
f"batch size {bsz}"
)
- assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
+ assert k_cache.size(-2) == v_cache.size(-2) == block_size, (
f"Got incompatible block size on kv caches:\n"
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
f"v_cache block_size {v_cache.size(-1)}"
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, "
f"v_cache block_size {v_cache.size(-2)}"
)

# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
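For one (sequence, head, kv-split) tile, the decoding kernel now reads cache blocks laid out as `[block_size, head_dim]`. A PyTorch reference of that inner step over a fully occupied block (a sketch: the real kernel keeps unnormalized exponentials plus a log-sum-exp per split and reduces the splits in a second kernel):

```python
import torch


def decode_tile_attention_sketch(
    q: torch.Tensor,        # [head_dim], the single decoding query of one head
    k_block: torch.Tensor,  # [block_size, head_dim], one occupied cache block
    v_block: torch.Tensor,  # [block_size, head_dim]
    sm_scale: float,
) -> torch.Tensor:
    # Kernel: S_ij += tl.sum(q[None, :] * k_cur_block, 1)
    s = (q[None, :] * k_block).sum(dim=1) * sm_scale  # [block_size]
    p = torch.softmax(s, dim=0)                       # normalized here for simplicity
    # Kernel: acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
    return (v_block * p[:, None]).sum(dim=0)          # [head_dim]
```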
33 changes: 14 additions & 19 deletions colossalai/kernel/triton/kvcache_copy.py
@@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
stride_kd,
stride_cacheb,
stride_cacheh,
- stride_cached,
stride_cachebs,
+ stride_cached,
stride_bts,
stride_btb,
block_size,
@@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel(
last_bt_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
- offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
+ offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)
offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ + offsets_in_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
- + offsets_in_last_block
)
tl.store(KVCache + offsets_kvcache, kv)
return
@@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
"""
Copy keys or values to the blocked key/value cache during decoding stage.
- Parameters:
- - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
- - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
- - kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
- - block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
+ Args:
+ k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
+ k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
+ kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
+ block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
- assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
+ assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
- if k.dim() == 4:
- assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
- bsz, _, num_kv_heads, head_dim = k.shape
- # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
- k = k.squeeze(dim=1)
- elif k.dim() == 3:
- bsz, num_kv_heads, head_dim = k.shape
- else:
- raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")

+ k = k.squeeze(1) if k.dim() == 4 else k
+ assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
+ bsz, num_kv_heads, head_dim = k.shape

assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
@@ -77,7 +72,7 @@
)

# Modify if the shape of kv cahce is changed.
- block_size = k_cache.size(-1)
+ block_size = k_cache.size(-2)

num_warps = 8 if head_dim > 128 else 4

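A pure-PyTorch reference of the decoding-stage copy under the revised layout (assumes `kv_lengths` already counts the token being written, as the docstring above states; a sketch, not the Triton kernel):

```python
import torch


def copy_kv_seqlen1_sketch(
    k: torch.Tensor,             # [bsz, num_kv_heads, head_dim], one new token per sequence
    k_cache: torch.Tensor,       # [num_blocks, num_kv_heads, block_size, head_dim]
    kv_lengths: torch.Tensor,    # [bsz], past length + 1 for the current token
    block_tables: torch.Tensor,  # [bsz, max_blocks_per_sequence]
) -> None:
    block_size = k_cache.size(-2)
    for i in range(k.size(0)):
        past_len = int(kv_lengths[i]) - 1
        block_id = block_tables[i, past_len // block_size]
        slot = past_len % block_size
        # The new token fills one slot along the block_size dimension for every kv head.
        k_cache[block_id, :, slot, :] = k[i]
```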
2 changes: 1 addition & 1 deletion tests/test_infer/test_kvcache_manager.py
@@ -93,7 +93,7 @@ def check_cache_manager(test_config):
assert len(cache_manager._cache_blocks) == num_blocks
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
assert len(key_caches) == num_layers
- expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
+ expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
assert key_caches[0].shape == expected_kv_shape
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
expected_kv_block_shape = expected_kv_shape[1:]