Commit

Move to a new branch to fix the DCO issues.
Signed-off-by: KuntaiDu <[email protected]>
Co-authored-by: ApostaC <[email protected]>
ApostaC committed Dec 3, 2024
1 parent 7c32b68 commit da35ed9
Showing 14 changed files with 793 additions and 18 deletions.
164 changes: 164 additions & 0 deletions benchmarks/benchmark_long_document_qa.py
@@ -0,0 +1,164 @@
"""
Benchmark the efficiency of prefix caching.
This script allows you to benchmark the performance of
a model with and without prefix caching using either fixed prompts
or prompts sampled from the ShareGPT dataset.
Fixed example usage:
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--num-prompts 1 \
--repeat-count 100
ShareGPT example usage:
# This command samples 20 prompts with input lengths
# between 128 and 256 tokens from the ShareGPT dataset,
# then replicates each prompt 5 times.
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
--enable-prefix-caching \
--num-prompts 20 \
--repeat-count 5 \
--input-length-range 128:256
"""

import random
import time

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser


def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
    """Run generation over the given prompts and report the elapsed time."""
    start_time = time.time()
    llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    print(f"Time to execute all requests: {end_time - start_time:.4f} secs")


def repeat_prompts(prompts, repeat_count):
    """Repeat each prompt `repeat_count` times and shuffle the result."""
    repeated_prompts = prompts * repeat_count
    random.shuffle(repeated_prompts)
    return repeated_prompts


def main(args):

random.seed(args.seed)

    # Prepend each document with its id so that no document is a prefix
    # of another document.
prompts = [
str(i) + ' '.join(['hi'] * args.document_length)
for i in range(args.num_documents)
]

    # CpuOffloadingBlockAllocator requires preemption by recomputation;
    # otherwise, preempt by swapping.
    if args.block_allocator == "CpuOffloadingBlockAllocator":
        preemption_mode = "recompute"
    else:
        preemption_mode = "swap"

llm = LLM(model=args.model,
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching,
block_allocator=args.block_allocator,
preemption_mode=preemption_mode,
swap_space=args.cpu_memory_gb,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=30000)

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

prompts = repeat_prompts(prompts, args.repeat_count)

print("------warm up------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)

print("------start generating------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Benchmark long-document QA performance with or without '
        'automatic prefix caching and CPU offloading.')
parser.add_argument(
'--model',
type=str,
        # This benchmark targets long-document QA, so we use Llama 3.1 8B,
        # which supports long contexts.
default='meta-llama/Llama-3.1-8B')
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--repeat-count',
type=int,
default=2,
help='Number of times to repeat each prompt')
    parser.add_argument(
        '--document-length',
        type=int,
        # Roughly the number of tokens in a systems paper, excluding figures.
        default=20010,
        help='Length of each synthetic document, in tokens.')
    parser.add_argument('--num-documents',
                        type=int,
                        default=8,
                        help='Number of synthetic documents to generate.')
parser.add_argument("--seed",
type=int,
default=0,
help='Random seed for reproducibility')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.5,
                        help='GPU memory utilization for vLLM. Should be a '
                        'floating-point number between 0 and 1. For this '
                        'test, please use a small value so that the GPU '
                        'cannot hold the KV caches of all documents and '
                        'the effect of CPU offloading can be observed.')
parser.add_argument(
'--cpu-memory-gb',
type=float,
default=1,
help="The amount of CPU memory (GB) that is used by vLLM. Not very "
"useful for CpuGpuBlockAllocator, but useful for "
"CpuOffloadingBlockAllocator to have more CPU KV cache space")
parser.add_argument(
'--block-allocator',
type=str,
default='CpuGpuBlockAllocator',
choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'],
        help='The block allocator that vLLM uses. Can be either '
        'CpuGpuBlockAllocator (the default) or CpuOffloadingBlockAllocator '
        '(experimental), which supports offloading the KV cache to CPU. '
        'When using CpuOffloadingBlockAllocator, the preemption mode must '
        'be recompute.')
args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion benchmarks/benchmark_prefix_caching.py
@@ -244,4 +244,4 @@ def main(args):

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)
main(args)
76 changes: 74 additions & 2 deletions csrc/cache_kernels.cu
@@ -21,8 +21,63 @@
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
namespace vllm {

template <typename T, typename ACC_T>
__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
ACC_T src_to_dst, const int num_pages,
const int num_elements_per_page) {
const int srcPageIdx = src_to_dst[blockIdx.x][0];
const int dstPageIdx = src_to_dst[blockIdx.x][1];

const int srcPageOffset = srcPageIdx * num_elements_per_page;
const int dstPageOffset = dstPageIdx * num_elements_per_page;

for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
dst[dstPageOffset + i] = src[srcPageOffset + i];
}
}

} // namespace vllm

template <int DTYPE_LEN, typename DTYPE>
void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
const torch::Tensor& block_mapping,
const int num_blocks,
const int block_size_in_bytes) {
auto block_mapping_accessor =
block_mapping.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>();

int num_threads = 1024;
int grid_size = num_blocks;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::paged_copy<<<grid_size, num_threads, 0, stream>>>(
dst, src, block_mapping_accessor, num_blocks,
block_size_in_bytes / DTYPE_LEN);
}

template <typename T>
T* get_kernel_ptr(torch::Tensor& tensor) {
// Get the kernel-accessible pointer of the given type T
// Returns NULL if the tensor is on CPU and non-pinned
torch::Device device = tensor.device();
if (device.is_cuda()) {
return static_cast<T*>(tensor.data_ptr());
} else if (device.is_cpu() && tensor.is_pinned()) {
T* ptr;
cudaHostGetDevicePointer((void**)&ptr, static_cast<T*>(tensor.data_ptr()),
0);
return ptr;
} else if (device.is_cpu()) {
return NULL;
} else {
TORCH_CHECK(false, "Invalid device");
}
}
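
For background on the pinned-memory branch above: get_kernel_ptr relies on
CUDA's mapped ("zero-copy") pinned host memory, where a pinned CPU buffer is
exposed to device code through cudaHostGetDevicePointer so a kernel can read
or write it directly across PCIe. The following is a minimal, self-contained
sketch of that pattern only; the kernel, sizes, and names are illustrative
and are not part of this change.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel: double each element of a buffer that lives in pinned host
// memory, accessed through its device-visible alias.
__global__ void double_elements(long long* data, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) data[idx] *= 2;
}

int main() {
  const int n = 1024;
  long long* host_ptr = nullptr;
  // cudaHostAllocMapped makes the pinned allocation addressable from the GPU.
  cudaHostAlloc(reinterpret_cast<void**>(&host_ptr), n * sizeof(long long),
                cudaHostAllocMapped);
  for (int i = 0; i < n; ++i) host_ptr[i] = i;

  // Same call as in get_kernel_ptr: obtain the device-visible alias.
  long long* dev_alias = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&dev_alias), host_ptr, 0);

  double_elements<<<(n + 255) / 256, 256>>>(dev_alias, n);
  cudaDeviceSynchronize();

  printf("host_ptr[3] = %lld\n", host_ptr[3]);  // prints 6
  cudaFreeHost(host_ptr);
  return 0;
}

This is also why the non-pinned CPU case returns NULL: pageable host memory
has no device-visible alias, so swap_blocks below falls back to
swap_blocks_slow.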

void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
@@ -62,6 +117,23 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
int64_t* src_ptr = get_kernel_ptr<int64_t>(src);
int64_t* dst_ptr = get_kernel_ptr<int64_t>(dst);
if (src_ptr == NULL || dst_ptr == NULL) {
// fall back to the slow implementation
swap_blocks_slow(src, dst, block_mapping.cpu());
} else {
const int64_t num_blocks = block_mapping.size(0);
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();

launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr,
block_mapping, num_blocks,
block_size_in_bytes);
}
}
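
To make the fast path's launch arithmetic concrete, here is a worked example
with assumed KV-cache shapes (the shapes are illustrative and not taken from
this change). The kernel copies each block as int64_t words, so the per-block
byte size must be a multiple of DTYPE_LEN = 8:

#include <cstdint>

// Hypothetical fp16 KV cache block: 16 tokens, 8 KV heads, head size 128.
constexpr int64_t elems_per_block = 16 * 8 * 128;             // src[0].numel() == 16384
constexpr int64_t block_size_in_bytes = elems_per_block * 2;  // element_size() == 2 -> 32768 bytes
static_assert(block_size_in_bytes % 8 == 0,
              "the fast path copies blocks in 8-byte (int64_t) chunks");
constexpr int64_t words_per_page = block_size_in_bytes / 8;   // 4096 int64_t words per page
// paged_copy launches one CUDA block of 1024 threads per page, so each
// thread copies words_per_page / 1024 == 4 words in its strided loop.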

namespace vllm {

// Grid: (num_layers, num_pairs)
134 changes: 134 additions & 0 deletions tests/core/block/test_cpu_offloading_block_allocator.py
@@ -0,0 +1,134 @@
import pytest

from vllm.core.block.cpu_offloading_block_allocator import (
CpuOffloadingBlockAllocator)
from vllm.utils import Device, chunk_list


@pytest.mark.parametrize("num_cpu_blocks", [1024])
@pytest.mark.parametrize("num_gpu_blocks", [256])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("allocator_type", ["prefix_caching"])
def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
block_size: int, allocator_type: str):
allocator = CpuOffloadingBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)

assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

gpu_blocks = [
allocator.allocate_mutable_block(prev_block=None, device=Device.GPU)
for _ in range(num_gpu_blocks)
]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == 0
assert len(allocator._uncached_blocks) == num_gpu_blocks

mapping = allocator.get_and_reset_swaps(0.0)
assert not mapping
assert len(allocator._uncached_blocks) == num_gpu_blocks

_ = [allocator.free(block) for block in gpu_blocks]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

mapping = allocator.get_and_reset_swaps(1.0)
assert not mapping
assert len(allocator._uncached_blocks) == 0


@pytest.mark.parametrize("num_cpu_blocks", [1024])
@pytest.mark.parametrize("num_gpu_blocks", [256])
@pytest.mark.parametrize("block_size", [2])
@pytest.mark.parametrize("allocator_type", ["prefix_caching"])
def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
block_size: int, allocator_type: str):
allocator = CpuOffloadingBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)

unique_token_ids = list(
range((num_cpu_blocks + num_gpu_blocks) * block_size))
gpu_token_ids = list(
chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size))
gpu_token_ids2 = list(
chunk_list(
unique_token_ids[num_gpu_blocks * block_size:2 * num_gpu_blocks *
block_size], block_size))

gpu_blocks = [
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids,
device=Device.GPU)
for token_ids in gpu_token_ids
]

assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == 0
assert len(allocator._uncached_blocks) == num_gpu_blocks

mapping = allocator.get_and_reset_swaps(0.0)
assert not mapping
assert len(allocator._uncached_blocks) == num_gpu_blocks

allocator.mark_blocks_as_computed([block.block_id for block in gpu_blocks])
mapping = allocator.get_and_reset_swaps(1.0)
assert len(mapping) == num_gpu_blocks
assert len(allocator._uncached_blocks) == 0

_ = [allocator.free(block) for block in gpu_blocks]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

mapping = allocator.get_and_reset_swaps(1.0)
assert len(mapping) == 0
assert len(allocator._uncached_blocks) == 0

# allocate another gpu sequence to flush out the GPU cache
gpu_blocks = [
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids,
device=Device.GPU)
for token_ids in gpu_token_ids2
]

assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == 0
assert all([
not allocator._allocators[Device.GPU].block_is_computed(block.block_id)
for block in gpu_blocks
])

_ = [allocator.free(block) for block in gpu_blocks]
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

mapping = allocator.get_and_reset_swaps(2.0)
assert len(mapping) == 0
assert len(allocator._uncached_blocks) == 0

# allocate original gpu sequence. It should hit CPU cache.
gpu_blocks = [
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids,
device=Device.GPU)
for token_ids in gpu_token_ids
]

delta = num_cpu_blocks - num_gpu_blocks
assert allocator.get_num_free_blocks(Device.CPU) == delta
assert allocator.get_num_free_blocks(Device.GPU) == 0
assert all([
allocator._allocators[Device.GPU].block_is_computed(block.block_id)
for block in gpu_blocks
])

mapping = allocator.get_and_reset_swaps(3.0)
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
(The remaining changed files are not shown.)
