Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Efficient transmission for CPU prefix caching, based on PR#10874 #11099

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 258 additions & 0 deletions benchmarks/benchmark_long_document_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
"""
Benchmark the efficiency of prefix caching.

This script allows you to benchmark the performance of
a model with prefix-caching or cpu-offloading using fixed prompts

Fixed example usage:
# This command run the vllm with 50GB CPU memory for offloading
# The workload samples 8 different prompts with a default input
# length of 20010 tokens, then replicates each prompt 2 times.
python benchmark_long_document_qa.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--block-allocator CpuOffloadingBlockAllocator \
--num-documents 8 \
--repeat-count 2 \
--cpu-memory-gb 50

Commandline arguments:

# Basic arguments
--model: The model to use for the benchmark.

--enable-prefix-caching: Enable prefix caching or not.

--block-allocator: The block allocator that vLLM uses.
- CpuGpuBlockAllocator: The default block allocator.
- CpuOffloadingBlockAllocator: The block allocator that supports
cpu offloading

--gpu-memory-utilization: GPU memory utilization for vLLM.

--cpu-memory-gb: The amount of CPU memory (GB) that is used by vLLM.
NOTE: CPU memory should be larger than GPU KV cache size when
using CpuOffloadingBlockAllocator.

# Workload-related arguments
--num-documents: The number of documents to sample prompts from.

--repeat-count: The number of times to repeat each prompt.

# Other functionality
--seed: Random seed for reproducibility.

--profile-swap-blocks: Profile the swap_blocks function in the custom ops.
"""

import random
import time

import torch

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

execution_times = {}


def build_result_dict(start_time, end_time, *args):
total_time = end_time - start_time
length = -1
if len(args) > 1 and isinstance(args[1], torch.Tensor):
length = len(args[1])

return {
"start_time": start_time,
"total_time": total_time,
"swap_len": length
}


def timing_decorator(func):

def wrapper(*args, **kwargs):
global execution_times
torch.cuda.synchronize()
start_time = time.time() # Record the start time
result = func(*args, **kwargs) # Call the wrapped function
torch.cuda.synchronize()
end_time = time.time() # Record the end time
if func.__name__ not in execution_times:
execution_times[func.__name__] = []

res = build_result_dict(start_time, end_time, *args)
execution_times[func.__name__].append(res)
return result # Return the result of the original function

return wrapper


def process_timing_results():
global execution_times
for key in execution_times:
len_to_time = {}
len_to_count = {}
for item in execution_times[key]:
swap_len = item["swap_len"]
if swap_len not in len_to_time:
len_to_time[swap_len] = 0
len_to_time[swap_len] += item["total_time"]

if swap_len not in len_to_count:
len_to_count[swap_len] = 0
len_to_count[swap_len] += 1

for swap_len in len_to_time:
total_time = len_to_time[swap_len]
count = len_to_count[swap_len]
print(f"{key} on {swap_len} pages: "
f"{(count * swap_len) / total_time} pages per second")


def test_long_document_qa(llm=None, sampling_params=None, prompts=None):

start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
print(f"cost time {end_time - start_time}")


def repeat_prompts(prompts, repeat_count):
repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts)
return repeated_prompts


def main(args):
if args.profile_swap_blocks:
from vllm.worker.cache_engine import CacheEngine
CacheEngine.swap_out = timing_decorator(CacheEngine.swap_out)
CacheEngine.swap_in = timing_decorator(CacheEngine.swap_in)

random.seed(args.seed)

# append the document id at the beginning to avoid any of the document
# being the prefix of other documents
prompts = [
str(i) + ' '.join(['hi'] * args.document_length)
for i in range(args.num_documents)
]

preemption_mode = ""
if args.block_allocator == "CpuOffloadingBlockAllocator":
preemption_mode = "recompute"
else:
preemption_mode = "swap"

llm = LLM(model=args.model,
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching,
block_allocator=args.block_allocator,
preemption_mode=preemption_mode,
swap_space=args.cpu_memory_gb,
enable_chunked_prefill=False,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=30000)

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

prompts = repeat_prompts(prompts, args.repeat_count)

print("------warm up------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)

random.shuffle(prompts)

print("------start generating------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)

if args.profile_swap_blocks:
process_timing_results()


if __name__ == "__main__":
parser = FlexibleArgumentParser(
description=
'Benchmark the performance with or without automatic prefix caching.')
parser.add_argument(
'--model',
type=str,
# this test aims to test long document QA capability,
# so we use llama 3.1 8B as it can process long context
default='meta-llama/Llama-3.1-8B')
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--repeat-count',
type=int,
default=2,
help='Number of times to repeat each prompt')
parser.add_argument(
'--document-length',
type=int,
# Roughly the number of tokens for a system paper,
# excluding images
default=20010,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
parser.add_argument('--num-documents',
type=int,
default=8,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
parser.add_argument("--seed",
type=int,
default=0,
help='Random seed for reproducibility')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='GPU memory utilization for vLLM. Should be a '
'float point number ranging from 0 to 1. For this '
'test please use a small value so that the GPU '
'cannot hold all KV caches of all documents, '
'and the effect of CPU offloading can be tested.')
parser.add_argument(
'--cpu-memory-gb',
type=float,
default=1,
help="The amount of CPU memory (GB) that is used by vLLM. Not very "
"useful for CpuGpuBlockAllocator, but useful for "
"CpuOffloadingBlockAllocator to have more CPU KV cache space")
parser.add_argument(
'--block-allocator',
type=str,
default='CpuGpuBlockAllocator',
choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'],
help='The block allocator that vLLM uses. Currently'
' can be CpuGpuBlockAllocator (the default) and '
'CpuOffloadingBlockAllocator (experimental) that '
'supports offloading the KV cache to CPU . '
'When using CpuOffloadingBlockAllocator, the '
'preemption mode must be recompute.')

parser.add_argument(
'--profile-swap-blocks',
action='store_true',
help='Profile the swap_blocks function in the custom ops')

args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,4 @@ def main(args):

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)
main(args)
96 changes: 94 additions & 2 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <cstdio>
#include <algorithm>
#include <cassert>
#include <map>
Expand All @@ -21,8 +22,64 @@
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
namespace vllm {

template <typename T, typename ACC_T>
__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
ACC_T src_to_dst, const int num_pages,
const int num_elements_per_page) {
const int64_t srcPageIdx = src_to_dst[blockIdx.x][0];
const int64_t dstPageIdx = src_to_dst[blockIdx.x][1];

const int64_t srcPageOffset = srcPageIdx * num_elements_per_page;
const int64_t dstPageOffset = dstPageIdx * num_elements_per_page;

for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
dst[dstPageOffset + i] = src[srcPageOffset + i];
}
}

} // namespace vllm

template <int DTYPE_LEN, typename DTYPE>
void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
const torch::Tensor& block_mapping,
const int num_blocks,
const int block_size_in_bytes) {
c10::cuda::CUDAGuard device_guard(block_mapping.device());
auto block_mapping_accessor =
block_mapping.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>();

int num_threads = 1024;
int grid_size = num_blocks;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::paged_copy<<<grid_size, num_threads, 0, stream>>>(
dst, src, block_mapping_accessor, num_blocks,
block_size_in_bytes / DTYPE_LEN);
}

template <typename T>
T* get_kernel_ptr(torch::Tensor& tensor) {
// Get the kernel-accessible pointer of the given type T
// Returns NULL if the tensor is on CPU and non-pinned
torch::Device device = tensor.device();
if (device.is_cuda()) {
return static_cast<T*>(tensor.data_ptr());
} else if (device.is_cpu() && tensor.is_pinned()) {
T* ptr;
cudaHostGetDevicePointer((void**)&ptr, static_cast<T*>(tensor.data_ptr()),
0);
return ptr;
} else if (device.is_cpu()) {
return NULL;
} else {
TORCH_CHECK(false, "Invalid device");
}
}

void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
Expand Down Expand Up @@ -62,6 +119,41 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
int64_t* src_ptr = get_kernel_ptr<int64_t>(src);
int64_t* dst_ptr = get_kernel_ptr<int64_t>(dst);
if (src_ptr == NULL || dst_ptr == NULL) {
// fall back to the slow implementation
swap_blocks_slow(src, dst, block_mapping.cpu());
} else {
// Check the device
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
torch::Device block_mapping_device = block_mapping.device();
TORCH_CHECK(block_mapping_device.is_cuda(), "block_mapping must be on GPU");
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
}
if (src_device.is_cuda()) {
TORCH_CHECK(src_device.index() == block_mapping_device.index(),
"src and block_mapping must be on the same GPU");
}
if (dst_device.is_cuda()) {
TORCH_CHECK(dst_device.index() == block_mapping_device.index(),
"src and block_mapping must be on the same GPU");
}

const int64_t num_blocks = block_mapping.size(0);
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();

launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr,
block_mapping, num_blocks,
block_size_in_bytes);
}
}

namespace vllm {

// Grid: (num_layers, num_pairs)
Expand Down
Loading
Loading