[Core] Support offloading KV cache to CPU #10874

Open · wants to merge 10 commits into main
164 changes: 164 additions & 0 deletions benchmarks/benchmark_long_document_qa.py
@@ -0,0 +1,164 @@
"""
Benchmark the efficiency of prefix caching.

This script allows you to benchmark the performance of
a model with and without prefix caching using either fixed prompts
or prompts sampled from the ShareGPT dataset.

Fixed example usage:
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--num-prompts 1 \
--repeat-count 100

ShareGPT example usage:
# This command samples 20 prompts with input lengths
# between 128 and 256 tokens from the ShareGPT dataset,
# then replicates each prompt 5 times.
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
--enable-prefix-caching \
--num-prompts 20 \
--repeat-count 5 \
--input-length-range 128:256
"""

import random
import time

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser


def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
    start_time = time.time()
    llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    print(f"Processed {len(prompts)} prompts in {end_time - start_time:.2f} s")


def repeat_prompts(prompts, repeat_count):
repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts)
return repeated_prompts


def main(args):

random.seed(args.seed)

    # Prepend the document id to each prompt so that no document is a
    # prefix of another (otherwise prefix caching could treat one document
    # as a shared prefix of another).
prompts = [
str(i) + ' '.join(['hi'] * args.document_length)
for i in range(args.num_documents)
]

preemption_mode = ""
if args.block_allocator == "CpuOffloadingBlockAllocator":
preemption_mode = "recompute"
else:
preemption_mode = "swap"

llm = LLM(model=args.model,
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching,
block_allocator=args.block_allocator,
preemption_mode=preemption_mode,
swap_space=args.cpu_memory_gb,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=30000)

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

prompts = repeat_prompts(prompts, args.repeat_count)

print("------warm up------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)

print("------start generating------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Benchmark long-document QA performance with and '
        'without KV cache offloading to CPU.')
parser.add_argument(
'--model',
type=str,
        # This benchmark targets long-document QA, so default to
        # Llama 3.1 8B, which supports a long context window.
default='meta-llama/Llama-3.1-8B')
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--repeat-count',
type=int,
default=2,
help='Number of times to repeat each prompt')
    parser.add_argument(
        '--document-length',
        type=int,
        # Roughly the number of tokens in a systems paper, excluding images
        default=20010,
        help='Length of each synthetic document in tokens; the prompt '
        'repeats a short filler token this many times.')
    parser.add_argument('--num-documents',
                        type=int,
                        default=8,
                        help='Number of synthetic documents to generate.')
parser.add_argument("--seed",
type=int,
default=0,
help='Random seed for reproducibility')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.5,
                        help='GPU memory utilization for vLLM. Should be a '
                        'floating-point number between 0 and 1. For this '
                        'test, use a small value so that the GPU cannot '
                        'hold the KV cache of every document and the '
                        'effect of CPU offloading can be observed.')
    parser.add_argument(
        '--cpu-memory-gb',
        type=float,
        default=1,
        help="The amount of CPU memory (GB) that vLLM can use. It has "
        "little effect with CpuGpuBlockAllocator, but gives "
        "CpuOffloadingBlockAllocator more room for offloaded KV cache.")
    parser.add_argument(
        '--block-allocator',
        type=str,
        default='CpuGpuBlockAllocator',
        choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'],
        help='The block allocator that vLLM uses. Either '
        'CpuGpuBlockAllocator (the default) or '
        'CpuOffloadingBlockAllocator (experimental), which '
        'supports offloading the KV cache to CPU. '
        'When using CpuOffloadingBlockAllocator, the '
        'preemption mode must be recompute.')
args = parser.parse_args()
main(args)
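The benchmark runs the same shuffled prompt set twice: the first pass fills the GPU KV cache (and, with CpuOffloadingBlockAllocator, offloads blocks to CPU memory), while the second pass measures how much of that cache can be reused. A minimal sketch of comparing the two passes directly, assuming a hypothetical variant of the timing helper that returns its elapsed time and the llm, prompts, and sampling_params objects built in main():

# Hypothetical variant of the timing helper (not part of this PR as written):
# return the elapsed time so the warm-up and measured passes can be compared.
import time

def timed_run(llm, prompts, sampling_params):
    start = time.time()
    llm.generate(prompts, sampling_params=sampling_params)
    return time.time() - start

# Assumes llm, prompts and sampling_params as constructed in main() above.
warmup_s = timed_run(llm, prompts, sampling_params)
cached_s = timed_run(llm, prompts, sampling_params)
print(f"warm-up: {warmup_s:.2f} s, cached: {cached_s:.2f} s, "
      f"speedup: {warmup_s / cached_s:.2f}x")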
2 changes: 1 addition & 1 deletion benchmarks/benchmark_prefix_caching.py
@@ -244,4 +244,4 @@ def main(args):

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)
main(args)
96 changes: 94 additions & 2 deletions csrc/cache_kernels.cu
@@ -11,6 +11,7 @@
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <cstdio>
#include <algorithm>
#include <cassert>
#include <map>
@@ -21,8 +22,64 @@
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
namespace vllm {

template <typename T, typename ACC_T>
__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
ACC_T src_to_dst, const int num_pages,
const int num_elements_per_page) {
const int64_t srcPageIdx = src_to_dst[blockIdx.x][0];
const int64_t dstPageIdx = src_to_dst[blockIdx.x][1];

const int64_t srcPageOffset = srcPageIdx * num_elements_per_page;
const int64_t dstPageOffset = dstPageIdx * num_elements_per_page;

for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
dst[dstPageOffset + i] = src[srcPageOffset + i];
}
}

} // namespace vllm
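The paged_copy kernel assigns one CUDA thread block to each (src, dst) page pair in block_mapping, with the threads of that block striding over the page's elements. A minimal PyTorch sketch of the same copy semantics, useful as a reference when checking the kernel; the helper name and layout assumptions below are hypothetical, not part of this PR:

# Reference semantics of paged_copy, sketched in PyTorch. `src` and `dst`
# are flat caches viewed as [num_pages, elements_per_page]; `block_mapping`
# is an int64 tensor of shape [num_pairs, 2] holding (src_page, dst_page).
import torch

def paged_copy_reference(dst, src, block_mapping, elements_per_page):
    dst_pages = dst.view(-1, elements_per_page)
    src_pages = src.view(-1, elements_per_page)
    for src_page, dst_page in block_mapping.tolist():
        # One CUDA thread block performs this copy in parallel.
        dst_pages[dst_page].copy_(src_pages[src_page])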

template <int DTYPE_LEN, typename DTYPE>
void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
const torch::Tensor& block_mapping,
const int num_blocks,
const int block_size_in_bytes) {
c10::cuda::CUDAGuard device_guard(block_mapping.device());
auto block_mapping_accessor =
block_mapping.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>();

int num_threads = 1024;
int grid_size = num_blocks;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::paged_copy<<<grid_size, num_threads, 0, stream>>>(
dst, src, block_mapping_accessor, num_blocks,
block_size_in_bytes / DTYPE_LEN);
}

template <typename T>
T* get_kernel_ptr(torch::Tensor& tensor) {
// Get the kernel-accessible pointer of the given type T
// Returns NULL if the tensor is on CPU and non-pinned
torch::Device device = tensor.device();
if (device.is_cuda()) {
return static_cast<T*>(tensor.data_ptr());
} else if (device.is_cpu() && tensor.is_pinned()) {
T* ptr;
cudaHostGetDevicePointer((void**)&ptr, static_cast<T*>(tensor.data_ptr()),
0);
return ptr;
} else if (device.is_cpu()) {
return NULL;
} else {
TORCH_CHECK(false, "Invalid device");
}
}
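get_kernel_ptr is what makes the GPU-driven copy possible: for a CUDA tensor it returns the raw data pointer, for a pinned (page-locked) CPU tensor it returns the device alias from cudaHostGetDevicePointer so the kernel can access host memory directly, and for a pageable CPU tensor it returns NULL, which routes the call to the cudaMemcpy-based slow path. A small PyTorch sketch of the caller-side assumption, namely that the CPU KV cache must be allocated in pinned memory to take the fast path (sizes below are illustrative, not from this PR):

# Assumed caller-side setup (not code from this PR): only a pinned CPU
# cache is directly addressable by the copy kernel.
import torch

num_blocks, block_numel = 1024, 16 * 128 * 16   # illustrative sizes
pinned_cpu_cache = torch.empty(num_blocks, block_numel,
                               dtype=torch.float16,
                               device="cpu", pin_memory=True)
pageable_cpu_cache = torch.empty(num_blocks, block_numel,
                                 dtype=torch.float16, device="cpu")

assert pinned_cpu_cache.is_pinned()        # eligible for the fast kernel path
assert not pageable_cpu_cache.is_pinned()  # falls back to swap_blocks_slow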

void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
@@ -62,6 +119,41 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
int64_t* src_ptr = get_kernel_ptr<int64_t>(src);
int64_t* dst_ptr = get_kernel_ptr<int64_t>(dst);
if (src_ptr == NULL || dst_ptr == NULL) {
// fall back to the slow implementation
swap_blocks_slow(src, dst, block_mapping.cpu());
} else {
// Check the device
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
torch::Device block_mapping_device = block_mapping.device();
TORCH_CHECK(block_mapping_device.is_cuda(), "block_mapping must be on GPU");
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
}
if (src_device.is_cuda()) {
TORCH_CHECK(src_device.index() == block_mapping_device.index(),
"src and block_mapping must be on the same GPU");
}
if (dst_device.is_cuda()) {
TORCH_CHECK(dst_device.index() == block_mapping_device.index(),
"src and block_mapping must be on the same GPU");
}

const int64_t num_blocks = block_mapping.size(0);
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();

launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr,
block_mapping, num_blocks,
block_size_in_bytes);
}
}
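The rewritten swap_blocks takes the kernel path whenever both tensors are device-addressable and, unlike the slow path, it requires block_mapping to already live on the GPU. The copy is issued in 8-byte words (launch_swap_block_kernel<8, int64_t>), so each block's byte size appears to be assumed to be a multiple of 8. A hedged caller-side sketch follows; the torch.ops._C_cache_ops.swap_blocks binding mirrors vLLM's existing cache-ops registration and should be treated as an assumption here, as are the shapes:

# Caller-side sketch: swap three blocks from a GPU cache into a pinned CPU
# cache. Shapes and the op binding below are assumptions for illustration.
import torch

gpu_cache = torch.rand(64, 16 * 128 * 16, dtype=torch.float16, device="cuda")
cpu_cache = torch.empty(gpu_cache.shape, dtype=gpu_cache.dtype,
                        device="cpu", pin_memory=True)

# (src_block, dst_block) pairs; the mapping itself must be on the GPU
# for the fast path.
block_mapping = torch.tensor([[0, 0], [5, 1], [9, 2]],
                             dtype=torch.int64, device="cuda")

torch.ops._C_cache_ops.swap_blocks(gpu_cache, cpu_cache, block_mapping)
torch.cuda.synchronize()  # the copy runs on the current CUDA stream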

namespace vllm {

// Grid: (num_layers, num_pairs)