Skip to content

Commit

Permalink
[feature] add KV cache manager for llama & bloom inference (hpcaitech…
Browse files Browse the repository at this point in the history
…#4495)

* add kv cache memory manager

* add stateinfo during inference

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* file dir change
  • Loading branch information
yuanheng-zhao authored and tiandiao123 committed Sep 7, 2023
1 parent 08d137b commit 57b5f25
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 0 deletions.
4 changes: 4 additions & 0 deletions colossalai/shardformer/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager

__all__ = ['BatchInferState', 'MemoryManager']
52 changes: 52 additions & 0 deletions colossalai/shardformer/inference/batch_infer_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Any

import torch

from .kvcache_manager import MemoryManager


@dataclass
class BatchInferState:
r"""
Information to be passed and used for a batch of inputs during
a single model forward
"""
batch_size: int
max_len_in_batch: int

cache_manager: MemoryManager = None

block_loc: torch.Tensor = None
start_loc: torch.Tensor = None
seq_len: torch.Tensor = None

is_context_stage: bool = False
context_mem_index: torch.Tensor = None
decode_is_contiguous: bool = None
decode_mem_start: int = None
decode_mem_end: int = None
decode_mem_index: torch.Tensor = None
decode_layer_id: int = None

device: torch.device = torch.device('cuda')

@property
def total_token_num(self):
return self.batch_size * self.max_len_in_batch

def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager

@staticmethod
def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
alloc_mem_index: torch.Tensor):
""" in-place update block loc mapping based on the sequence length of the inputs in current bath"""
start_index = 0
seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy):
b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
cur_seq_len]
start_index += cur_seq_len
return
116 changes: 116 additions & 0 deletions colossalai/shardformer/inference/kvcache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
#
# Copyright 2023 ModelTC Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from colossalai.logging import get_dist_logger


class MemoryManager:
r"""
Manage token block indexes and allocate physical memory for key and value cache
Args:
size: maximum token number used as the size of key and value buffer
dtype: data type of cached key and value
head_num: number of heads the memory manager is responsible for
head_dim: embedded size per head
layer_num: the number of layers in the model
device: device used to store the key and value cache
"""

def __init__(self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device('cuda')):
self.logger = get_dist_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
self._init_mem_states(size, device)
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)

def _init_mem_states(self, size, device):
""" Initialize tensors used to manage memory states """
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)

def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
""" Initialize key buffer and value buffer on specified device """
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
self.value_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]

@torch.no_grad()
def alloc(self, required_size):
""" allocate space of required_size by providing indexes representing available physical spaces """
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} "
f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
select_index = self.indexes[select_index]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
return select_index

@torch.no_grad()
def alloc_contiguous(self, required_size):
""" allocate contiguous space of required_size """
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} "
f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
1] + self.mem_state[0:sum_size -
required_size + 1]
can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
self.logger.info(f"No enough contiguous cache: required_size {required_size} "
f"left_size {self.available_size}")
return None
start_loc = can_used_loc[0]
select_index = self.indexes[start_loc:start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
end = start + required_size
return select_index, start, end

@torch.no_grad()
def free(self, free_index):
""" free memory by updating memory states based on given indexes """
self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1

@torch.no_grad()
def free_all(self):
""" free all memory by updating memory states """
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.past_key_values_length = 0
self.logger.info("freed all space of memory manager")
60 changes: 60 additions & 0 deletions tests/test_infer/test_kvcache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os

import pytest
import torch

from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.inference import MemoryManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128


def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(port)
disable_existing_loggers()

size = batch_size * (input_len + output_len)
kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
key_buffers = kvcache_manager.key_buffer
value_buffers = kvcache_manager.value_buffer
assert len(key_buffers) == len(value_buffers) == layer_num
assert key_buffers[0].shape == value_buffers[0].shape
# required size exceeds the maximum allocated size
invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
assert invalid_locs is None
# for prefill stage, allocation via alloc and alloc_contiguous should be the same
total_token_prefill = batch_size * input_len
prefill_locs = kvcache_manager.alloc(total_token_prefill)
kvcache_manager.free_all()
prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
assert torch.equal(prefill_locs, prefill_locs_contiguous)
assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
kvcache_manager.alloc_contiguous(batch_size)
assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
spawn(create_cache_manager,
4,
batch_size=BATCH_SIZE,
input_len=INPUT_LEN,
output_len=OUTPUT_LEN,
layer_num=LAYER_NUM,
head_num=HEAD_NUM,
head_dim=HEAD_DIM)


if __name__ == '__main__':
test_cache_manager_dist()

0 comments on commit 57b5f25

Please sign in to comment.