forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature] add KV cache manager for llama & bloom inference (hpcaitech…
…#4495) * add kv cache memory manager * add stateinfo during inference * format * format * rename file * add kv cache test * revise on BatchInferState * file dir change
- Loading branch information
1 parent
08d137b
commit 57b5f25
Showing
4 changed files
with
232 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .batch_infer_state import BatchInferState | ||
from .kvcache_manager import MemoryManager | ||
|
||
__all__ = ['BatchInferState', 'MemoryManager'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later | ||
from dataclasses import dataclass | ||
from typing import Any | ||
|
||
import torch | ||
|
||
from .kvcache_manager import MemoryManager | ||
|
||
|
||
@dataclass | ||
class BatchInferState: | ||
r""" | ||
Information to be passed and used for a batch of inputs during | ||
a single model forward | ||
""" | ||
batch_size: int | ||
max_len_in_batch: int | ||
|
||
cache_manager: MemoryManager = None | ||
|
||
block_loc: torch.Tensor = None | ||
start_loc: torch.Tensor = None | ||
seq_len: torch.Tensor = None | ||
|
||
is_context_stage: bool = False | ||
context_mem_index: torch.Tensor = None | ||
decode_is_contiguous: bool = None | ||
decode_mem_start: int = None | ||
decode_mem_end: int = None | ||
decode_mem_index: torch.Tensor = None | ||
decode_layer_id: int = None | ||
|
||
device: torch.device = torch.device('cuda') | ||
|
||
@property | ||
def total_token_num(self): | ||
return self.batch_size * self.max_len_in_batch | ||
|
||
def set_cache_manager(self, manager: MemoryManager): | ||
self.cache_manager = manager | ||
|
||
@staticmethod | ||
def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, | ||
alloc_mem_index: torch.Tensor): | ||
""" in-place update block loc mapping based on the sequence length of the inputs in current bath""" | ||
start_index = 0 | ||
seq_len_numpy = seq_len.cpu().numpy() | ||
for i, cur_seq_len in enumerate(seq_len_numpy): | ||
b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + | ||
cur_seq_len] | ||
start_index += cur_seq_len | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Adapted from lightllm/common/mem_manager.py | ||
# of the ModelTC/lightllm GitHub repository | ||
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py | ||
# | ||
# Copyright 2023 ModelTC Team | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
|
||
from colossalai.logging import get_dist_logger | ||
|
||
|
||
class MemoryManager: | ||
r""" | ||
Manage token block indexes and allocate physical memory for key and value cache | ||
Args: | ||
size: maximum token number used as the size of key and value buffer | ||
dtype: data type of cached key and value | ||
head_num: number of heads the memory manager is responsible for | ||
head_dim: embedded size per head | ||
layer_num: the number of layers in the model | ||
device: device used to store the key and value cache | ||
""" | ||
|
||
def __init__(self, | ||
size: int, | ||
dtype: torch.dtype, | ||
head_num: int, | ||
head_dim: int, | ||
layer_num: int, | ||
device: torch.device = torch.device('cuda')): | ||
self.logger = get_dist_logger(__name__) | ||
self.available_size = size | ||
self.past_key_values_length = 0 | ||
self._init_mem_states(size, device) | ||
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) | ||
|
||
def _init_mem_states(self, size, device): | ||
""" Initialize tensors used to manage memory states """ | ||
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) | ||
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) | ||
self.indexes = torch.arange(0, size, dtype=torch.long, device=device) | ||
|
||
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): | ||
""" Initialize key buffer and value buffer on specified device """ | ||
self.key_buffer = [ | ||
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) | ||
] | ||
self.value_buffer = [ | ||
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) | ||
] | ||
|
||
@torch.no_grad() | ||
def alloc(self, required_size): | ||
""" allocate space of required_size by providing indexes representing available physical spaces """ | ||
if required_size > self.available_size: | ||
self.logger.warning(f"No enough cache: required_size {required_size} " | ||
f"left_size {self.available_size}") | ||
return None | ||
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) | ||
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) | ||
select_index = self.indexes[select_index] | ||
self.mem_state[select_index] = 0 | ||
self.available_size -= len(select_index) | ||
return select_index | ||
|
||
@torch.no_grad() | ||
def alloc_contiguous(self, required_size): | ||
""" allocate contiguous space of required_size """ | ||
if required_size > self.available_size: | ||
self.logger.warning(f"No enough cache: required_size {required_size} " | ||
f"left_size {self.available_size}") | ||
return None | ||
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) | ||
sum_size = len(self.mem_cum_sum) | ||
loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + | ||
1] + self.mem_state[0:sum_size - | ||
required_size + 1] | ||
can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] | ||
if can_used_loc.shape[0] == 0: | ||
self.logger.info(f"No enough contiguous cache: required_size {required_size} " | ||
f"left_size {self.available_size}") | ||
return None | ||
start_loc = can_used_loc[0] | ||
select_index = self.indexes[start_loc:start_loc + required_size] | ||
self.mem_state[select_index] = 0 | ||
self.available_size -= len(select_index) | ||
start = start_loc.item() | ||
end = start + required_size | ||
return select_index, start, end | ||
|
||
@torch.no_grad() | ||
def free(self, free_index): | ||
""" free memory by updating memory states based on given indexes """ | ||
self.available_size += free_index.shape[0] | ||
self.mem_state[free_index] = 1 | ||
|
||
@torch.no_grad() | ||
def free_all(self): | ||
""" free all memory by updating memory states """ | ||
self.available_size = len(self.mem_state) | ||
self.mem_state[:] = 1 | ||
self.past_key_values_length = 0 | ||
self.logger.info("freed all space of memory manager") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
|
||
import pytest | ||
import torch | ||
|
||
from colossalai.logging import disable_existing_loggers | ||
from colossalai.shardformer.inference import MemoryManager | ||
from colossalai.testing import rerun_if_address_is_in_use, spawn | ||
|
||
BATCH_SIZE = 4 | ||
INPUT_LEN = 16 | ||
OUTPUT_LEN = 8 | ||
LAYER_NUM = 4 | ||
HEAD_NUM = 32 | ||
HEAD_DIM = 128 | ||
|
||
|
||
def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): | ||
os.environ['RANK'] = str(rank) | ||
os.environ['LOCAL_RANK'] = str(rank) | ||
os.environ['WORLD_SIZE'] = str(world_size) | ||
os.environ['MASTER_ADDR'] = 'localhost' | ||
os.environ['MASTER_PORT'] = str(port) | ||
disable_existing_loggers() | ||
|
||
size = batch_size * (input_len + output_len) | ||
kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank) | ||
key_buffers = kvcache_manager.key_buffer | ||
value_buffers = kvcache_manager.value_buffer | ||
assert len(key_buffers) == len(value_buffers) == layer_num | ||
assert key_buffers[0].shape == value_buffers[0].shape | ||
# required size exceeds the maximum allocated size | ||
invalid_locs = kvcache_manager.alloc_contiguous(size + 1) | ||
assert invalid_locs is None | ||
# for prefill stage, allocation via alloc and alloc_contiguous should be the same | ||
total_token_prefill = batch_size * input_len | ||
prefill_locs = kvcache_manager.alloc(total_token_prefill) | ||
kvcache_manager.free_all() | ||
prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0] | ||
assert torch.equal(prefill_locs, prefill_locs_contiguous) | ||
assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill | ||
kvcache_manager.alloc_contiguous(batch_size) | ||
assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False) | ||
|
||
|
||
@pytest.mark.dist | ||
@rerun_if_address_is_in_use() | ||
def test_cache_manager_dist(): | ||
spawn(create_cache_manager, | ||
4, | ||
batch_size=BATCH_SIZE, | ||
input_len=INPUT_LEN, | ||
output_len=OUTPUT_LEN, | ||
layer_num=LAYER_NUM, | ||
head_num=HEAD_NUM, | ||
head_dim=HEAD_DIM) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_cache_manager_dist() |