From f60a8fa481a09f3ec07c15669f069e452291c07a Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 3 Dec 2024 20:22:31 +0000 Subject: [PATCH 01/17] Move to a new branch to fix the DCO issues. Signed-off-by: ApostaC Co-authored-by: KuntaiDu --- benchmarks/benchmark_long_document_qa.py | 164 +++++++++ benchmarks/benchmark_prefix_caching.py | 2 +- csrc/cache_kernels.cu | 76 +++- .../test_cpu_offloading_block_allocator.py | 134 ++++++++ tests/kernels/test_cache.py | 2 +- vllm/config.py | 9 + vllm/core/block/cpu_gpu_block_allocator.py | 9 +- .../block/cpu_offloading_block_allocator.py | 325 ++++++++++++++++++ vllm/core/block_manager.py | 10 +- vllm/core/scheduler.py | 48 ++- vllm/engine/arg_utils.py | 21 ++ vllm/engine/llm_engine.py | 5 +- vllm/entrypoints/llm.py | 2 + vllm/worker/worker.py | 4 +- 14 files changed, 793 insertions(+), 18 deletions(-) create mode 100644 benchmarks/benchmark_long_document_qa.py create mode 100644 tests/core/block/test_cpu_offloading_block_allocator.py create mode 100644 vllm/core/block/cpu_offloading_block_allocator.py diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py new file mode 100644 index 0000000000000..82e37aaccef96 --- /dev/null +++ b/benchmarks/benchmark_long_document_qa.py @@ -0,0 +1,164 @@ +""" +Benchmark the efficiency of prefix caching. + +This script allows you to benchmark the performance of +a model with and without prefix caching using either fixed prompts +or prompts sampled from the ShareGPT dataset. + +Fixed example usage: + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 + +ShareGPT example usage: + # This command samples 20 prompts with input lengths + # between 128 and 256 tokens from the ShareGPT dataset, + # then replicates each prompt 5 times. + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +""" + +import random +import time + +from vllm import LLM, SamplingParams +from vllm.utils import FlexibleArgumentParser + + +def test_long_document_qa(llm=None, sampling_params=None, prompts=None): + + start_time = time.time() + llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + print(f"cost time {end_time - start_time}") + + +def repeat_prompts(prompts, repeat_count): + repeated_prompts = prompts * repeat_count + random.shuffle(repeated_prompts) + return repeated_prompts + + +def main(args): + + random.seed(args.seed) + + # append the document id at the beginning to avoid any of the document + # being the prefix of other documents + prompts = [ + str(i) + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents) + ] + + preemption_mode = "" + if args.block_allocator == "CpuOffloadingBlockAllocator": + preemption_mode = "recompute" + else: + preemption_mode = "swap" + + llm = LLM(model=args.model, + tokenizer_mode='auto', + trust_remote_code=True, + enforce_eager=True, + tensor_parallel_size=args.tensor_parallel_size, + enable_prefix_caching=args.enable_prefix_caching, + block_allocator=args.block_allocator, + preemption_mode=preemption_mode, + swap_space=args.cpu_memory_gb, + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=30000) + + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + + prompts = repeat_prompts(prompts, args.repeat_count) + + print("------warm up------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + print("------start generating------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description= + 'Benchmark the performance with or without automatic prefix caching.') + parser.add_argument( + '--model', + type=str, + # this test aims to test long document QA capability, + # so we use llama 3.1 8B as it can process long context + default='meta-llama/Llama-3.1-8B') + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='enable prefix caching') + parser.add_argument('--repeat-count', + type=int, + default=2, + help='Number of times to repeat each prompt') + parser.add_argument( + '--document-length', + type=int, + # Roughly the number of tokens for a system paper, + # excluding images + default=20010, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + parser.add_argument('--num-documents', + type=int, + default=8, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + parser.add_argument("--seed", + type=int, + default=0, + help='Random seed for reproducibility') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.5, + help='GPU memory utilization for vLLM. Should be a ' + 'float point number ranging from 0 to 1. For this ' + 'test please use a small value so that the GPU ' + 'cannot hold all KV caches of all documents, ' + 'and the effect of CPU offloading can be tested.') + parser.add_argument( + '--cpu-memory-gb', + type=float, + default=1, + help="The amount of CPU memory (GB) that is used by vLLM. Not very " + "useful for CpuGpuBlockAllocator, but useful for " + "CpuOffloadingBlockAllocator to have more CPU KV cache space") + parser.add_argument( + '--block-allocator', + type=str, + default='CpuGpuBlockAllocator', + choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'], + help='The block allocator that vLLM uses. Currently' + ' can be CpuGpuBlockAllocator (the default) and ' + 'CpuOffloadingBlockAllocator (experimental) that ' + 'supports offloading the KV cache to CPU . ' + 'When using CpuOffloadingBlockAllocator, the ' + 'preemption mode must be recompute.') + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 5e9381f712e10..9a8ecae7b65df 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -244,4 +244,4 @@ def main(args): parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..d1eab933dcf93 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -21,8 +21,63 @@ typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { +namespace vllm { + +template +__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src, + ACC_T src_to_dst, const int num_pages, + const int num_elements_per_page) { + const int srcPageIdx = src_to_dst[blockIdx.x][0]; + const int dstPageIdx = src_to_dst[blockIdx.x][1]; + + const int srcPageOffset = srcPageIdx * num_elements_per_page; + const int dstPageOffset = dstPageIdx * num_elements_per_page; + + for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) { + dst[dstPageOffset + i] = src[srcPageOffset + i]; + } +} + +} // namespace vllm + +template +void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src, + const torch::Tensor& block_mapping, + const int num_blocks, + const int block_size_in_bytes) { + auto block_mapping_accessor = + block_mapping.packed_accessor32(); + + int num_threads = 1024; + int grid_size = num_blocks; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::paged_copy<<>>( + dst, src, block_mapping_accessor, num_blocks, + block_size_in_bytes / DTYPE_LEN); +} + +template +T* get_kernel_ptr(torch::Tensor& tensor) { + // Get the kernel-accessible pointer of the given type T + // Returns NULL if the tensor is on CPU and non-pinned + torch::Device device = tensor.device(); + if (device.is_cuda()) { + return static_cast(tensor.data_ptr()); + } else if (device.is_cpu() && tensor.is_pinned()) { + T* ptr; + cudaHostGetDevicePointer((void**)&ptr, static_cast(tensor.data_ptr()), + 0); + return ptr; + } else if (device.is_cpu()) { + return NULL; + } else { + TORCH_CHECK(false, "Invalid device"); + } +} + +void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; @@ -62,6 +117,23 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, } } +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + int64_t* src_ptr = get_kernel_ptr(src); + int64_t* dst_ptr = get_kernel_ptr(dst); + if (src_ptr == NULL || dst_ptr == NULL) { + // fall back to the slow implementation + swap_blocks_slow(src, dst, block_mapping.cpu()); + } else { + const int64_t num_blocks = block_mapping.size(0); + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + + launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr, + block_mapping, num_blocks, + block_size_in_bytes); + } +} + namespace vllm { // Grid: (num_layers, num_pairs) diff --git a/tests/core/block/test_cpu_offloading_block_allocator.py b/tests/core/block/test_cpu_offloading_block_allocator.py new file mode 100644 index 0000000000000..d8eec348c5d73 --- /dev/null +++ b/tests/core/block/test_cpu_offloading_block_allocator.py @@ -0,0 +1,134 @@ +import pytest + +from vllm.core.block.cpu_offloading_block_allocator import ( + CpuOffloadingBlockAllocator) +from vllm.utils import Device, chunk_list + + +@pytest.mark.parametrize("num_cpu_blocks", [1024]) +@pytest.mark.parametrize("num_gpu_blocks", [256]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("allocator_type", ["prefix_caching"]) +def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): + allocator = CpuOffloadingBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + gpu_blocks = [ + allocator.allocate_mutable_block(prev_block=None, device=Device.GPU) + for _ in range(num_gpu_blocks) + ] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert len(allocator._uncached_blocks) == num_gpu_blocks + + mapping = allocator.get_and_reset_swaps(0.0) + assert not mapping + assert len(allocator._uncached_blocks) == num_gpu_blocks + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + mapping = allocator.get_and_reset_swaps(1.0) + assert not mapping + assert len(allocator._uncached_blocks) == 0 + + +@pytest.mark.parametrize("num_cpu_blocks", [1024]) +@pytest.mark.parametrize("num_gpu_blocks", [256]) +@pytest.mark.parametrize("block_size", [2]) +@pytest.mark.parametrize("allocator_type", ["prefix_caching"]) +def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): + allocator = CpuOffloadingBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + unique_token_ids = list( + range((num_cpu_blocks + num_gpu_blocks) * block_size)) + gpu_token_ids = list( + chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size)) + gpu_token_ids2 = list( + chunk_list( + unique_token_ids[num_gpu_blocks * block_size:2 * num_gpu_blocks * + block_size], block_size)) + + gpu_blocks = [ + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids + ] + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert len(allocator._uncached_blocks) == num_gpu_blocks + + mapping = allocator.get_and_reset_swaps(0.0) + assert not mapping + assert len(allocator._uncached_blocks) == num_gpu_blocks + + allocator.mark_blocks_as_computed([block.block_id for block in gpu_blocks]) + mapping = allocator.get_and_reset_swaps(1.0) + assert len(mapping) == num_gpu_blocks + assert len(allocator._uncached_blocks) == 0 + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + mapping = allocator.get_and_reset_swaps(1.0) + assert len(mapping) == 0 + assert len(allocator._uncached_blocks) == 0 + + # allocate another gpu sequence to flush out the GPU cache + gpu_blocks = [ + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids2 + ] + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert all([ + not allocator._allocators[Device.GPU].block_is_computed(block.block_id) + for block in gpu_blocks + ]) + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + mapping = allocator.get_and_reset_swaps(2.0) + assert len(mapping) == 0 + assert len(allocator._uncached_blocks) == 0 + + # allocate original gpu sequence. It should hit CPU cache. + gpu_blocks = [ + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids + ] + + delta = num_cpu_blocks - num_gpu_blocks + assert allocator.get_num_free_blocks(Device.CPU) == delta + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert all([ + allocator._allocators[Device.GPU].block_is_computed(block.block_id) + for block in gpu_blocks + ]) + + mapping = allocator.get_and_reset_swaps(3.0) + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 40550ed51e2c7..c37438bdff8c3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -362,7 +362,7 @@ def test_swap_blocks( block_mapping = list(zip(src_blocks, dst_blocks)) block_mapping_tensor = torch.tensor(block_mapping, dtype=torch.int64, - device="cpu").view(-1, 2) + device="cuda").view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( diff --git a/vllm/config.py b/vllm/config.py index 971eb36d677b8..b12d89815184d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -759,6 +759,7 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + block_allocator: str = "CpuGpuBlockAllocator", ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -769,6 +770,7 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb + self.block_allocator = block_allocator self._verify_args() self._verify_cache_dtype() @@ -789,6 +791,13 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.block_allocator not in [ + "CpuGpuBlockAllocator", "CpuOffloadingBlockAllocator" + ]: + raise ValueError( + "Only CpuGpuBlockAllocator and CpuOffloadingBlockAllocator is " + "supported. Got %s." % self.block_allocator) + def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3197af3c2b7a4..c1a7216b3604b 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -322,17 +322,20 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() - def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. Currently not useful. + + Args: + now (float): The time stamp. Returns: List[Tuple[int, int]]: A mapping of source to destination block IDs. """ - mapping = self._swap_mapping.copy() self._swap_mapping.clear() - return list(mapping.items()) + # return an empty list, to keep compatibility with previous behavior + return [] def find_cached_blocks_prefix( self, diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py new file mode 100644 index 0000000000000..5fbc517477f2c --- /dev/null +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -0,0 +1,325 @@ +"""This file implement a block allocator that supports CPU KV cache offloading + +The key idea of this implementation is to maintain those allocated blocks +that didn't hit the cache, and constantly copy them into CPU after each +scheduler step. + +This idea is borrowed from ConServe +(paper link: https://arxiv.org/abs/2410.01228), based on the assumption +that the CPU-GPU bandwidth is much higher than GPU KV cache generation +throughput. Thanks Yifan for this idea. + +This implementation also allows vLLM to gracefully handle preemption by +recomputation. +""" +from collections import deque +from typing import Deque, Dict, List, Optional, Tuple + +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.utils import Device + + +class CpuOffloadingBlockAllocator(CpuGpuBlockAllocator): + """A block allocator that supports CPU KV cache offloading + + This class extends the `CpuGpuBlockAllocator` so that the CPU can be used + for prefix caching. + + It will internally maintain uncached blocks, and trying to copy uncached + blocks into CPU upon the end of scheduler step (i.e. calling + `get_and_reset_swaps`). + + This implementation also allows vLLM to gracefully handle preemption by + recomputation. + """ + + allocators: Dict[Device, PrefixCachingBlockAllocator] + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Initiate CpuOffloadingBlockAllocator. Similar to + CpuGpuBlockAllocator.create() but only support prefix caching + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuOffloadingBlockAllocator instance + with the specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + assert num_gpu_blocks < num_cpu_blocks, "CPU offloading block "\ + "allocator requires the allocated CPU memory capacity to be larger"\ + " than GPU memory capacity." + block_ids = list(range(num_gpu_blocks + num_cpu_blocks)) + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + assert allocator_type == "prefix_caching", "CpuOffloadingBlock"\ + "Allocator should be only used together with prefix caching." + + # prefix caching block is now the default. + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + + return CpuOffloadingBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__(self, cpu_block_allocator: PrefixCachingBlockAllocator, + gpu_block_allocator: PrefixCachingBlockAllocator): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + super().__init__(cpu_block_allocator, gpu_block_allocator) + self._allocators: Dict[Device, + PrefixCachingBlockAllocator] = { # type: ignore + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator + } + """ + GPU block should only be in one of the following three status: + uncached: allocated blocks that didn't hit any cache + cached: allocated blocks that are cached, either in GPU or in CPU + free: the blocks are not allocated by block allocator + This implementation aims to transform uncacherd blocks to cached blocks + by performing GPU to CPU copy when calling `get_and_reset_swaps` + + As block allocator will automatically track free blocks, and we don't + need to specially handle cached blocks. So we only track uncached blocks + """ + self._uncached_blocks: Deque[Block] = deque() + """ + We probe CPU cache hit by trying to allocate a CPU + block and see if it is computed. + If we hit the CPU cache, we cannot free this CPU block until the end + of scheduler step, in order to avoid the CPU cache being overwritten. + so we track the cpu blocks we allocated, and free it after scheduler + step (i.e. calling `get_and_reset_swaps`). + """ + self._allocated_cpu_blocks: Deque[Block] = deque() + + def allocate_mutable_block(self, prev_block: Optional[Block], + device: Device) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated mutable block. + """ + assert device == Device.GPU, "Calls to CPU offloading block allocator "\ + "should always use Device.GPU --- CPU offloading block allocator "\ + "handles CPU offloading internally."\ + # mark this block as uncached + + block = self._allocators[device].allocate_mutable_block(prev_block) + self._uncached_blocks.append(block) + return block + + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. + """ + + assert device == Device.GPU, "Calls to CPU offloading block allocator "\ + "should always use Device.GPU --- CPU offloading block allocator"\ + "handles CPU offloading internally." + + # repeatedly call allocate_immutable_block + # because it handles CPU-GPU offloading related logics. + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device) + blocks.append(prev_block) + return blocks + + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + + assert device == Device.GPU, "Calls to CPU offloading block allocator"\ + " should always use Device.GPU --- CPU offloading block allocator"\ + " handles CPU offloading internally." + + # allocate a GPU block + block = self._allocators[device].allocate_immutable_block( + prev_block, token_ids) + block_id = block.block_id + assert block_id is not None + block_computed = self._allocators[device].block_is_computed(block_id) + + # deal with prefix caching, three cases in total: + # 1. cache hit on GPU + # 2. no cache hit on GPU but cache hit on CPU + # 3. no cache hit + if block_computed: + # cache hit on GPU, no need to put it into uncached blocks + pass + else: + # check if we can hit cache on CPU by trying to allocate CPU block + cpu_block = self._allocators[Device.CPU].allocate_immutable_block( + prev_block, token_ids) + cpu_block_id = cpu_block.block_id + assert cpu_block_id is not None + cpu_block_computed = self._allocators[ + Device.CPU].block_is_computed(cpu_block_id) + if cpu_block_computed: + # CPU cache hit + # mark the GPU block as computed + self._allocators[Device.GPU].mark_blocks_as_computed( + [block_id]) + # copy the CPU cache to GPU + self._swap_mapping[cpu_block_id] = block_id + # and don't free this block until `get_and_reset_swap` is called + self._allocated_cpu_blocks.append(cpu_block) + else: + # No cache hit + # mark the GPU block as uncached + self._uncached_blocks.append(block) + # and free cpu block + self._allocators[Device.CPU].free(cpu_block) + + return block + + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + + raise NotImplementedError("CPU offloading block allocator only " + "support preemption by recomputation.") + + def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: + """Returns and clears the mapping of source to destination block IDs. + Will be called right before scheduler step finishes. + + This function will do the following things: + 1. Iterate over uncached blocks and see if we can copy it to CPU + 2. Update all allocated CPU block time stamp + 3. Free CPU blocks + 4. Return and clear all swapping status + + Args: + now (float): The time stamp used to update CPU access time, so + that CPU evictor can work. + + Returns: + List[Tuple[int, int]]: A mapping of source to destination block IDs. + """ + + allocator = self._allocators[Device.GPU] + cpu_allocator = self._allocators[Device.CPU] + + new_uncached_blocks: Deque[Block] = deque() + + while self._uncached_blocks: + block = self._uncached_blocks.pop() + block_id = block.block_id + + # check if this block is freed + if block_id is None: + # this block is already freed, no longer need to copy it to CPU + continue + + refcount = allocator._refcounter.get(block_id) + assert refcount > 0, "A freed block should have block_id None" + + # check if this block is computed + computed = allocator.block_is_computed(block_id) + if computed: # This block is computed, copy it to CPU + # allocate a block on CPU + cpu_block = cpu_allocator.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) + assert cpu_block.block_id is not None + self._allocated_cpu_blocks.append(cpu_block) + + # mark CPU block as computed + cpu_allocator.mark_blocks_as_computed([cpu_block.block_id]) + + # copy the GPU block to CPU + assert cpu_block.block_id is not None + self._swap_mapping[block_id] = cpu_block.block_id + + continue + + # this block is neither freed nor computed + # keep marking it as uncached + new_uncached_blocks.append(block) + + # update uncached blocks + self._uncached_blocks = new_uncached_blocks + + # iterate over allocated CPU blocks, update access time and free them + # need to update access time so that CPU evictor can work + while self._allocated_cpu_blocks: + cpu_block = self._allocated_cpu_blocks.pop() + assert cpu_block.block_id is not None + # update the access time + cpu_allocator.mark_blocks_as_accessed([cpu_block.block_id], now) + # free the block + cpu_allocator.free(cpu_block) + + # return the mapping + mapping = self._swap_mapping.copy() + self._swap_mapping.clear() + return list(mapping.items()) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 209487c6b4f9e..987aefdb71c11 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,6 +5,8 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.cpu_offloading_block_allocator import ( + CpuOffloadingBlockAllocator) from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) @@ -16,6 +18,11 @@ SeqId = int EncoderSeqId = str +block_allocator_creator = { + "CpuGpuBlockAllocator": CpuGpuBlockAllocator.create, + "CpuOffloadingBlockAllocator": CpuOffloadingBlockAllocator.create, +} + class SelfAttnBlockSpaceManager(BlockSpaceManager): """BlockSpaceManager which manages the allocation of KV cache. @@ -65,6 +72,7 @@ def __init__( watermark: float = 0.01, sliding_window: Optional[int] = None, enable_caching: bool = False, + block_allocator: str = "CpuGpuBlockAllocator", ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks @@ -90,7 +98,7 @@ def __init__( self.watermark_blocks = int(watermark * num_gpu_blocks) - self.block_allocator = CpuGpuBlockAllocator.create( + self.block_allocator = block_allocator_creator[block_allocator]( allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d23009dae01ee..5ad8643ec7cf2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -152,7 +152,9 @@ class SchedulerOutputs: def __post_init__(self): # Swap in and swap out should never happen at the same time. - assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + # NOTE(Kuntai): in CpuOffloadingBlockAllocator swap in and swap out + # will happen at the same time. So we comment out the following line. + # assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: @@ -349,7 +351,8 @@ def __init__( num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + enable_caching=self.cache_config.enable_prefix_caching, + block_allocator=self.cache_config.block_allocator) # Sequence groups in the WAITING state. # Contain new prefill or preempted requests. @@ -1122,6 +1125,21 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_copy = running_scheduled.blocks_to_copy blocks_to_copy.extend(swapped_in.blocks_to_copy) + blocks_to_swap_in = swapped_in.blocks_to_swap_in + blocks_to_swap_out = running_scheduled.blocks_to_swap_out + + # NOTE(Kuntai): extend the swapping list for CPU offloading + block_allocator = self.block_manager.block_allocator + mapping = block_allocator.get_and_reset_swaps(time.time()) + for src, dst in mapping: + # only two possible cases: CPU -> GPU, or GPU -> CPU + if src in block_allocator._allocators[Device.GPU].all_block_ids: + # swap out + blocks_to_swap_out.extend((src, dst)) + else: + # swap in + blocks_to_swap_in.extend((src, dst)) + ignored_seq_groups = prefills.ignored_seq_groups ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) @@ -1200,6 +1218,25 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + blocks_to_swap_in = swapped_in.blocks_to_swap_in + blocks_to_swap_out = running_scheduled.blocks_to_swap_out + + # NOTE(Kuntai): extend the swapping list for CPU offloading + block_allocator = self.block_manager.block_allocator + mapping = block_allocator.get_and_reset_swaps(time.time()) + for src, dst in mapping: + # only two possible cases: CPU -> GPU, or GPU -> CPU + if src in block_allocator._allocators[Device.GPU].all_block_ids: + # swap out + blocks_to_swap_out.extend((src, dst)) + else: + # swap in + blocks_to_swap_in.extend((src, dst)) + # Put prefills first due to Attention backend ordering assumption. scheduled_seq_groups = (prefills.seq_groups + running_scheduled.prefill_seq_groups + @@ -1222,10 +1259,9 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, num_lookahead_slots=num_lookahead_slots, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3b776c1d9d39f..2d92be65a8131 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -112,6 +112,7 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None + block_allocator: str = "CpuGpuBlockAllocator" # NOTE(kzawora): default block size for Gaudi should be 128 # smaller sizes still work, but very inefficiently block_size: int = 16 if not current_platform.is_hpu() else 128 @@ -412,6 +413,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='If specified, use nsight to profile Ray workers.') # KV cache arguments + parser.add_argument( + '--block-allocator', + type=str, + default='CpuGpuBlockAllocator', + choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'], + help='The block allocator that vLLM uses. Currently' + ' can be CpuGpuBlockAllocator (the default) and ' + 'CpuOffloadingBlockAllocator (experimental) that ' + 'supports offloading the KV cache to CPU . ' + 'When using CpuOffloadingBlockAllocator, the ' + 'preemption mode must be recompute.') parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, @@ -1006,6 +1018,14 @@ def create_engine_config(self, "CPU offload space must be non-negative" f", but got {self.cpu_offload_gb}") + if self.block_allocator == "CpuOffloadingBlockAllocator" and \ + self.preemption_mode == "swap": + raise ValueError( + "CpuOffloadingBlockAllocator only supports preemption by " + "recomputation as it internally offloads the request KV cache " + "to CPU. Please add `--preemption-mode recomputation` to vLLM " + "engine args") + device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() @@ -1028,6 +1048,7 @@ def create_engine_config(self, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, + block_allocator=self.block_allocator, ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index af66b307028cf..35869d37f9ea0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -255,11 +255,11 @@ def __init__( "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, block_allocator=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s, " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " "mm_processor_kwargs=%s, pooler_config=%r," @@ -284,6 +284,7 @@ def __init__( self.model_config.quantization, self.model_config.enforce_eager, self.cache_config.cache_dtype, + self.cache_config.block_allocator, self.model_config.quantization_param_path, self.device_config.device, self.decoding_config, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 65fa9873df28c..ac0e1d779cac3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -163,6 +163,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: float = 4, cpu_offload_gb: float = 0, + block_allocator: str = "CpuOffloadingBlockAllocator", enforce_eager: Optional[bool] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, @@ -214,6 +215,7 @@ def __init__( gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, + block_allocator=block_allocator, enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 094dd5a5d08b3..ffa3c2af51a2b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -322,10 +322,10 @@ def prepare_worker_input( # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", + device="cuda", dtype=torch.int64).view(-1, 2) blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", + device="cuda", dtype=torch.int64).view(-1, 2) # `blocks_to_copy` is a gpu tensor. The src and tgt of # blocks to copy are in the same device, and `blocks_to_copy` From e6654f2cf03f08b7228863900c78b0e75891bd8f Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 3 Dec 2024 20:28:36 +0000 Subject: [PATCH 02/17] [Fix] the failed unit tests Signed-off-by: ApostaC --- csrc/cache_kernels.cu | 30 ++++++++++++++++++++++++++---- tests/kernels/test_cache.py | 2 +- vllm/entrypoints/llm.py | 2 +- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d1eab933dcf93..b55eafa286b91 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -11,6 +11,7 @@ #include "quantization/fp8/nvidia/quant_utils.cuh" #endif +#include #include #include #include @@ -27,11 +28,12 @@ template __global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src, ACC_T src_to_dst, const int num_pages, const int num_elements_per_page) { - const int srcPageIdx = src_to_dst[blockIdx.x][0]; - const int dstPageIdx = src_to_dst[blockIdx.x][1]; + const int64_t srcPageIdx = src_to_dst[blockIdx.x][0]; + const int64_t dstPageIdx = src_to_dst[blockIdx.x][1]; - const int srcPageOffset = srcPageIdx * num_elements_per_page; - const int dstPageOffset = dstPageIdx * num_elements_per_page; + + const int64_t srcPageOffset = srcPageIdx * num_elements_per_page; + const int64_t dstPageOffset = dstPageIdx * num_elements_per_page; for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) { dst[dstPageOffset + i] = src[srcPageOffset + i]; @@ -45,6 +47,7 @@ void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src, const torch::Tensor& block_mapping, const int num_blocks, const int block_size_in_bytes) { + c10::cuda::CUDAGuard device_guard(block_mapping.device()); auto block_mapping_accessor = block_mapping.packed_accessor32(); @@ -125,6 +128,25 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, // fall back to the slow implementation swap_blocks_slow(src, dst, block_mapping.cpu()); } else { + // Check the device + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + torch::Device block_mapping_device = block_mapping.device(); + TORCH_CHECK(block_mapping_device.is_cuda(), + "block_mapping must be on GPU"); + if (src_device.is_cuda() && dst_device.is_cuda()) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + } + if (src_device.is_cuda()) { + TORCH_CHECK(src_device.index() == block_mapping_device.index(), + "src and block_mapping must be on the same GPU"); + } + if (dst_device.is_cuda()) { + TORCH_CHECK(dst_device.index() == block_mapping_device.index(), + "src and block_mapping must be on the same GPU"); + } + const int64_t num_blocks = block_mapping.size(0); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index c37438bdff8c3..ef90c36dd81ab 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -362,7 +362,7 @@ def test_swap_blocks( block_mapping = list(zip(src_blocks, dst_blocks)) block_mapping_tensor = torch.tensor(block_mapping, dtype=torch.int64, - device="cuda").view(-1, 2) + device=device).view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ac0e1d779cac3..f6f71f65f7a50 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -163,7 +163,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: float = 4, cpu_offload_gb: float = 0, - block_allocator: str = "CpuOffloadingBlockAllocator", + block_allocator: str = "CpuGpuBlockAllocator", enforce_eager: Optional[bool] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, From ba6c9e3b6fc9969a9e2bdf9af0c7f05e1580e8a7 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Wed, 4 Dec 2024 06:07:35 +0000 Subject: [PATCH 03/17] [Fix] CPU offloading not working bug and [fix] unit test and format issues Signed-off-by: ApostaC --- csrc/cache_kernels.cu | 4 +-- vllm/config.py | 2 +- vllm/core/block/cpu_gpu_block_allocator.py | 9 +++-- .../block/cpu_offloading_block_allocator.py | 32 +++++++++++++---- vllm/core/block/interfaces.py | 17 +++++++++ vllm/core/block/prefix_caching_block.py | 15 +++++++- vllm/core/block_manager.py | 23 ++++++++++++ vllm/core/interfaces.py | 17 +++++++++ vllm/core/placeholder_block_space_manager.py | 4 +++ vllm/core/scheduler.py | 36 ++++++++----------- 10 files changed, 123 insertions(+), 36 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b55eafa286b91..934f37034d11f 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -31,7 +31,6 @@ __global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src, const int64_t srcPageIdx = src_to_dst[blockIdx.x][0]; const int64_t dstPageIdx = src_to_dst[blockIdx.x][1]; - const int64_t srcPageOffset = srcPageIdx * num_elements_per_page; const int64_t dstPageOffset = dstPageIdx * num_elements_per_page; @@ -132,8 +131,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); torch::Device block_mapping_device = block_mapping.device(); - TORCH_CHECK(block_mapping_device.is_cuda(), - "block_mapping must be on GPU"); + TORCH_CHECK(block_mapping_device.is_cuda(), "block_mapping must be on GPU"); if (src_device.is_cuda() && dst_device.is_cuda()) { TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); diff --git a/vllm/config.py b/vllm/config.py index b12d89815184d..426aa0104d62f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -796,7 +796,7 @@ def _verify_args(self) -> None: ]: raise ValueError( "Only CpuGpuBlockAllocator and CpuOffloadingBlockAllocator is " - "supported. Got %s." % self.block_allocator) + f"supported. Got {self.block_allocator}.") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index c1a7216b3604b..b9e91589b3c27 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -322,7 +322,8 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() - def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. Currently not useful. @@ -331,11 +332,13 @@ def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: now (float): The time stamp. Returns: - List[Tuple[int, int]]: A mapping of source to destination block IDs. + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. """ self._swap_mapping.clear() # return an empty list, to keep compatibility with previous behavior - return [] + return [], [] def find_cached_blocks_prefix( self, diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py index 5fbc517477f2c..da1af85a70fa2 100644 --- a/vllm/core/block/cpu_offloading_block_allocator.py +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -166,7 +166,6 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block], List[Block]: The newly allocated list of immutable blocks containing the provided block token IDs. """ - assert device == Device.GPU, "Calls to CPU offloading block allocator "\ "should always use Device.GPU --- CPU offloading block allocator"\ "handles CPU offloading internally." @@ -249,7 +248,8 @@ def swap(self, blocks: List[Block], src_device: Device, raise NotImplementedError("CPU offloading block allocator only " "support preemption by recomputation.") - def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. Will be called right before scheduler step finishes. @@ -264,7 +264,9 @@ def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: that CPU evictor can work. Returns: - List[Tuple[int, int]]: A mapping of source to destination block IDs. + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. """ allocator = self._allocators[Device.GPU] @@ -319,7 +321,25 @@ def get_and_reset_swaps(self, now: float) -> List[Tuple[int, int]]: # free the block cpu_allocator.free(cpu_block) - # return the mapping - mapping = self._swap_mapping.copy() + # populate the swap_out list and swap_in list + blocks_to_swap_out = [] + blocks_to_swap_in = [] + for src, dst in self._swap_mapping.items(): + # only two possible cases: CPU -> GPU, or GPU -> CPU + if src in self._allocators[Device.GPU].all_block_ids: + # swap out + blocks_to_swap_out.append((src, dst)) + else: + # swap in + blocks_to_swap_in.append((src, dst)) self._swap_mapping.clear() - return list(mapping.items()) + return blocks_to_swap_out, blocks_to_swap_in + + def will_swap_in_cpu_blocks(self): + """Check if there are CPU blocks that will be swapped in + + Returns: + bool: True if there are CPU blocks that will be swapped in, False + otherwise. + """ + return bool(self._swap_mapping) diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 06f4851af3466..d36001fdeeab4 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -290,3 +290,20 @@ def find_cached_blocks_prefix( device: Device = Device.GPU, ) -> List[int]: pass + + @abstractmethod + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. Currently not useful. + + Args: + now (float): The time stamp. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. + """ + pass diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index b736167f6ceb4..2a4b55585e7a9 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -907,6 +907,8 @@ def __init__( # `get_num_cached_tokens` for more details. self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + self._seq_id_has_cpu_blocks: Set[int] = set() + def _update_seq_hashes(self, seq: Sequence) -> None: """Incrementally update the sequence's block hashes and record them.""" assert self._enable_caching @@ -962,7 +964,8 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: # TODO(rickyx): This hack could be removed once we mark blocks as # computed correctly with chunked prefills. - if num_computed_tokens_prev is not None and seq.is_prefill(): + if num_computed_tokens_prev is not None and seq.is_prefill() \ + and seq.seq_id not in self._seq_id_has_cpu_blocks: # For a sequence that is still in prefill, we don't # recompute the number of cached tokens. # This also handles correctly chunked prefill since currently @@ -980,6 +983,14 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens return num_cached_tokens + def on_swap_in_cpu_blocks(self, seq_id: int) -> None: + """Mark the sequence as having CPU blocks swapped in.""" + # NOTE(Yihua): This is a temporary solution to handle the case where + # the CPU offloading is enabled and the sequence has CPU blocks swapped + # in. In this case, the number in self._seq_id_to_num_tokens_computed + # should be invalidated and we need to re-compute it. + self._seq_id_has_cpu_blocks.add(seq_id) + def remove_seq(self, seq_id: int) -> None: """Stop tracking the sequence.""" if not self._enable_caching: @@ -990,6 +1001,8 @@ def remove_seq(self, seq_id: int) -> None: assert seq_id in self._seq_id_to_num_tokens_computed del self._seq_id_to_num_tokens_computed[seq_id] + self._seq_id_has_cpu_blocks.discard(seq_id) + class LastAccessBlocksTracker: """Manages the last access time of the tracked sequences, in order to allow diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 987aefdb71c11..aa11a72e8e631 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -162,6 +162,13 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: # Add blocks to the block table only if the sequence is non empty. block_table.allocate(seq.get_token_ids()) + # If the block allocator is CpuOffloadingBlockAllocator, we need to + # tell the computed_blocks_tracker to invalidate the previous computed + # num cached tokens + if isinstance(self.block_allocator, CpuOffloadingBlockAllocator) and \ + self.block_allocator.will_swap_in_cpu_blocks(): + self._computed_blocks_tracker.on_swap_in_cpu_blocks(seq.seq_id) + return block_table def allocate(self, seq_group: SequenceGroup) -> None: @@ -516,3 +523,19 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: cached in the block manager for the sequence. """ return self._computed_blocks_tracker.get_num_cached_tokens(seq) + + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. + + Args: + now (float): The time stamp. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. + """ + return self.block_allocator.get_and_reset_swaps(now) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b10b8d3f4a5bf..948b2b63643a5 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -125,3 +125,20 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: @abstractmethod def get_num_cached_tokens(self, seq: Sequence) -> int: pass + + @abstractmethod + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. + + Args: + now (float): The time stamp. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. + """ + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index 26d42b7f1790e..4c7ac2f8cfb31 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -92,3 +92,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_num_cached_tokens(self, seq: Sequence) -> int: return 0 + + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + return [], [] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5ad8643ec7cf2..6e88ab77e26d6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1129,16 +1129,12 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_swap_out = running_scheduled.blocks_to_swap_out # NOTE(Kuntai): extend the swapping list for CPU offloading - block_allocator = self.block_manager.block_allocator - mapping = block_allocator.get_and_reset_swaps(time.time()) - for src, dst in mapping: - # only two possible cases: CPU -> GPU, or GPU -> CPU - if src in block_allocator._allocators[Device.GPU].all_block_ids: - # swap out - blocks_to_swap_out.extend((src, dst)) - else: - # swap in - blocks_to_swap_in.extend((src, dst)) + new_swap_out, new_swap_in = \ + self.block_manager.get_and_reset_swaps(time.time()) + for src, dst in new_swap_out: + blocks_to_swap_out.extend((src, dst)) + for src, dst in new_swap_in: + blocks_to_swap_in.extend((src, dst)) ignored_seq_groups = prefills.ignored_seq_groups ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) @@ -1148,8 +1144,8 @@ def _schedule_default(self) -> SchedulerOutputs: num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, @@ -1226,16 +1222,12 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: blocks_to_swap_out = running_scheduled.blocks_to_swap_out # NOTE(Kuntai): extend the swapping list for CPU offloading - block_allocator = self.block_manager.block_allocator - mapping = block_allocator.get_and_reset_swaps(time.time()) - for src, dst in mapping: - # only two possible cases: CPU -> GPU, or GPU -> CPU - if src in block_allocator._allocators[Device.GPU].all_block_ids: - # swap out - blocks_to_swap_out.extend((src, dst)) - else: - # swap in - blocks_to_swap_in.extend((src, dst)) + new_swap_out, new_swap_in = \ + self.block_manager.get_and_reset_swaps(time.time()) + for src, dst in new_swap_out: + blocks_to_swap_out.extend((src, dst)) + for src, dst in new_swap_in: + blocks_to_swap_in.extend((src, dst)) # Put prefills first due to Attention backend ordering assumption. scheduled_seq_groups = (prefills.seq_groups + From 1c949851d8ad4e6f48905ec0a74e154f21f1d0ec Mon Sep 17 00:00:00 2001 From: ApostaC Date: Wed, 4 Dec 2024 20:12:23 +0000 Subject: [PATCH 04/17] [fix] broken tests for cpu offloading allocator Signed-off-by: ApostaC --- .../test_cpu_offloading_block_allocator.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/core/block/test_cpu_offloading_block_allocator.py b/tests/core/block/test_cpu_offloading_block_allocator.py index d8eec348c5d73..df4dbc40f12e1 100644 --- a/tests/core/block/test_cpu_offloading_block_allocator.py +++ b/tests/core/block/test_cpu_offloading_block_allocator.py @@ -29,16 +29,18 @@ def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.GPU) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks - mapping = allocator.get_and_reset_swaps(0.0) - assert not mapping + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(0.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks _ = [allocator.free(block) for block in gpu_blocks] assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - mapping = allocator.get_and_reset_swaps(1.0) - assert not mapping + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 @@ -75,21 +77,23 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.GPU) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks - mapping = allocator.get_and_reset_swaps(0.0) - assert not mapping + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(0.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks allocator.mark_blocks_as_computed([block.block_id for block in gpu_blocks]) - mapping = allocator.get_and_reset_swaps(1.0) - assert len(mapping) == num_gpu_blocks + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + assert len(blocks_to_swap_out) + len(blocks_to_swap_in) == num_gpu_blocks assert len(allocator._uncached_blocks) == 0 _ = [allocator.free(block) for block in gpu_blocks] assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - mapping = allocator.get_and_reset_swaps(1.0) - assert len(mapping) == 0 + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 # allocate another gpu sequence to flush out the GPU cache @@ -110,8 +114,9 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, _ = [allocator.free(block) for block in gpu_blocks] assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - mapping = allocator.get_and_reset_swaps(2.0) - assert len(mapping) == 0 + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(2.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 # allocate original gpu sequence. It should hit CPU cache. @@ -130,5 +135,5 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, for block in gpu_blocks ]) - mapping = allocator.get_and_reset_swaps(3.0) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(3.0) assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks From daab0d6aa7b59610d3b67fd41878107cd83340e6 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 8 Dec 2024 21:15:27 +0000 Subject: [PATCH 05/17] [Fix] add the call to get_physical_block_ids Signed-off-by: ApostaC --- vllm/core/block/cpu_gpu_block_allocator.py | 3 ++- vllm/core/block/cpu_offloading_block_allocator.py | 7 ++++++- vllm/core/block/interfaces.py | 3 ++- vllm/core/block_manager.py | 3 ++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index b9e91589b3c27..f1bc7eca7d977 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -334,7 +334,8 @@ def get_and_reset_swaps(self, Returns: A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). Each list is a List[Tuple[int, int]], containing the mapping of - source to destination block IDs. + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. """ self._swap_mapping.clear() # return an empty list, to keep compatibility with previous behavior diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py index da1af85a70fa2..c38815b04a266 100644 --- a/vllm/core/block/cpu_offloading_block_allocator.py +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -266,7 +266,8 @@ def get_and_reset_swaps(self, Returns: A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). Each list is a List[Tuple[int, int]], containing the mapping of - source to destination block IDs. + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. """ allocator = self._allocators[Device.GPU] @@ -328,9 +329,13 @@ def get_and_reset_swaps(self, # only two possible cases: CPU -> GPU, or GPU -> CPU if src in self._allocators[Device.GPU].all_block_ids: # swap out + src = self._allocators[Device.GPU].get_physical_block_id(src) + dst = self._allocators[Device.CPU].get_physical_block_id(dst) blocks_to_swap_out.append((src, dst)) else: # swap in + src = self._allocators[Device.CPU].get_physical_block_id(src) + dst = self._allocators[Device.GPU].get_physical_block_id(dst) blocks_to_swap_in.append((src, dst)) self._swap_mapping.clear() return blocks_to_swap_out, blocks_to_swap_in diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index d36001fdeeab4..469cb1f3f2f9b 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -304,6 +304,7 @@ def get_and_reset_swaps(self, Returns: A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). Each list is a List[Tuple[int, int]], containing the mapping of - source to destination block IDs. + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. """ pass diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index aa11a72e8e631..887791132cd78 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -536,6 +536,7 @@ def get_and_reset_swaps(self, Returns: A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). Each list is a List[Tuple[int, int]], containing the mapping of - source to destination block IDs. + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. """ return self.block_allocator.get_and_reset_swaps(now) From 919e5e39dfaa73006e51c940f41dc7913ed855aa Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 9 Dec 2024 01:05:24 +0000 Subject: [PATCH 06/17] [Add] faster unsafe implementation for get_physical_block_id Signed-off-by: ApostaC --- .../block/cpu_offloading_block_allocator.py | 48 +++++++++++++++++-- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py index c38815b04a266..d03a619e6936e 100644 --- a/vllm/core/block/cpu_offloading_block_allocator.py +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -128,6 +128,9 @@ def __init__(self, cpu_block_allocator: PrefixCachingBlockAllocator, """ self._allocated_cpu_blocks: Deque[Block] = deque() + self.num_gpu_blocks = gpu_block_allocator.get_num_total_blocks() + self.num_cpu_blocks = cpu_block_allocator.get_num_total_blocks() + def allocate_mutable_block(self, prev_block: Optional[Block], device: Device) -> Block: """Allocates a new mutable block on the specified device. @@ -248,6 +251,40 @@ def swap(self, blocks: List[Block], src_device: Device, raise NotImplementedError("CPU offloading block allocator only " "support preemption by recomputation.") + def _is_gpu_block(self, block_id: int) -> bool: + return block_id in self._allocators[Device.GPU].all_block_ids + + def _is_gpu_block_unsafe(self, block_id: int) -> bool: + """Faster version of `_is_gpu_block` that doesn't check the block ID. + But assumes the that the block IDs are assigned contiguously, with GPU + block IDs coming before the CPU block IDs. + """ + return block_id < self.num_gpu_blocks + + def _get_physical_block_id_unsafe(self, block_id: int) -> int: + """Returns the physical block ID of the given block ID. + + This function avoids using the `allocator.get_physical_block_id()` + which is slow (O(NlogN)). Instead, this is based on the assumption + that the block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + + Args: + block_id (int): The block ID to get the physical block ID of. + + Returns: + int: The physical block ID of the given block ID. + + Note: + Please see the implementation of + `CpuOffloadingBlockAllocator.create` for how the block IDs are + assigned. + """ + if self._is_gpu_block_unsafe(block_id): + return block_id + else: + return block_id - self.num_gpu_blocks + def get_and_reset_swaps(self, now: float) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. @@ -327,15 +364,16 @@ def get_and_reset_swaps(self, blocks_to_swap_in = [] for src, dst in self._swap_mapping.items(): # only two possible cases: CPU -> GPU, or GPU -> CPU - if src in self._allocators[Device.GPU].all_block_ids: + #if src in self._allocators[Device.GPU].all_block_ids: + if self._is_gpu_block_unsafe(src): # swap out - src = self._allocators[Device.GPU].get_physical_block_id(src) - dst = self._allocators[Device.CPU].get_physical_block_id(dst) + src = self._get_physical_block_id_unsafe(src) + dst = self._get_physical_block_id_unsafe(dst) blocks_to_swap_out.append((src, dst)) else: # swap in - src = self._allocators[Device.CPU].get_physical_block_id(src) - dst = self._allocators[Device.GPU].get_physical_block_id(dst) + src = self._get_physical_block_id_unsafe(src) + dst = self._get_physical_block_id_unsafe(dst) blocks_to_swap_in.append((src, dst)) self._swap_mapping.clear() return blocks_to_swap_out, blocks_to_swap_in From 063821184400350f554cd7baaaf9bc73c0c0e84a Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Mon, 9 Dec 2024 12:03:49 +0000 Subject: [PATCH 07/17] Feat: support CSR format to construct the swapped blocks sequence IDs with each swapped blocks Signed-off-by: Dahai Tang --- vllm/core/block/cpu_gpu_block_allocator.py | 6 ++- .../block/cpu_offloading_block_allocator.py | 41 ++++++++------ vllm/core/block/interfaces.py | 11 ++-- vllm/core/block_manager.py | 23 +++++++- vllm/core/scheduler.py | 54 ++++++++++++++----- vllm/engine/llm_engine.py | 3 ++ vllm/sequence.py | 6 +++ vllm/worker/worker.py | 12 +++++ vllm/worker/worker_base.py | 9 ++++ 9 files changed, 125 insertions(+), 40 deletions(-) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index f1bc7eca7d977..337899a272fb1 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -322,8 +322,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() - def get_and_reset_swaps(self, - now: float) -> Tuple[List[Tuple[int, int]], ...]: + def get_and_reset_swaps(self) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. Currently not useful. @@ -341,6 +340,9 @@ def get_and_reset_swaps(self, # return an empty list, to keep compatibility with previous behavior return [], [] + def access_cpu_hit_blocks(self, now: float) -> None: + pass + def find_cached_blocks_prefix( self, block_hashes: List[int], diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py index d03a619e6936e..039b29653ddaa 100644 --- a/vllm/core/block/cpu_offloading_block_allocator.py +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -285,8 +285,7 @@ def _get_physical_block_id_unsafe(self, block_id: int) -> int: else: return block_id - self.num_gpu_blocks - def get_and_reset_swaps(self, - now: float) -> Tuple[List[Tuple[int, int]], ...]: + def get_and_reset_swaps(self) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. Will be called right before scheduler step finishes. @@ -296,10 +295,6 @@ def get_and_reset_swaps(self, 3. Free CPU blocks 4. Return and clear all swapping status - Args: - now (float): The time stamp used to update CPU access time, so - that CPU evictor can work. - Returns: A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). Each list is a List[Tuple[int, int]], containing the mapping of @@ -312,6 +307,7 @@ def get_and_reset_swaps(self, new_uncached_blocks: Deque[Block] = deque() + # XXX(lixiaobai09): may slow for each request to iterate over all? while self._uncached_blocks: block = self._uncached_blocks.pop() block_id = block.block_id @@ -326,7 +322,9 @@ def get_and_reset_swaps(self, # check if this block is computed computed = allocator.block_is_computed(block_id) - if computed: # This block is computed, copy it to CPU + # This block is computed or immutable, copy it to CPU + if computed or \ + (block.content_hash is not None): # allocate a block on CPU cpu_block = cpu_allocator.allocate_immutable_block( prev_block=block.prev_block, token_ids=block.token_ids) @@ -349,16 +347,6 @@ def get_and_reset_swaps(self, # update uncached blocks self._uncached_blocks = new_uncached_blocks - # iterate over allocated CPU blocks, update access time and free them - # need to update access time so that CPU evictor can work - while self._allocated_cpu_blocks: - cpu_block = self._allocated_cpu_blocks.pop() - assert cpu_block.block_id is not None - # update the access time - cpu_allocator.mark_blocks_as_accessed([cpu_block.block_id], now) - # free the block - cpu_allocator.free(cpu_block) - # populate the swap_out list and swap_in list blocks_to_swap_out = [] blocks_to_swap_in = [] @@ -378,6 +366,25 @@ def get_and_reset_swaps(self, self._swap_mapping.clear() return blocks_to_swap_out, blocks_to_swap_in + def access_cpu_hit_blocks(self, now: float) -> None: + ''' + Args: + now (float): The time stamp used to update CPU access time, so + that CPU evictor can work. + ''' + + # iterate over allocated CPU blocks, update access time and free them + # need to update access time so that CPU evictor can work + cpu_allocator = self._allocators[Device.CPU] + while self._allocated_cpu_blocks: + cpu_block = self._allocated_cpu_blocks.pop() + assert cpu_block.block_id is not None + # update the access time + cpu_allocator.mark_blocks_as_accessed([cpu_block.block_id], now) + # free the block + cpu_allocator.free(cpu_block) + + def will_swap_in_cpu_blocks(self): """Check if there are CPU blocks that will be swapped in diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 469cb1f3f2f9b..7a530f7462900 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -292,15 +292,11 @@ def find_cached_blocks_prefix( pass @abstractmethod - def get_and_reset_swaps(self, - now: float) -> Tuple[List[Tuple[int, int]], ...]: + def get_and_reset_swaps(self) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. Currently not useful. - Args: - now (float): The time stamp. - Returns: A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). Each list is a List[Tuple[int, int]], containing the mapping of @@ -308,3 +304,8 @@ def get_and_reset_swaps(self, IDs and it's expected to be used by the cache engine directly. """ pass + + @abstractmethod + def access_cpu_hit_blocks(self, now: float) -> None: + """Access cache hitted blocks on CPU to update last accessed time.""" + pass diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 887791132cd78..30c438b0ba964 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -113,6 +113,11 @@ def __init__( self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) + # request_id -> (blocks_to_swap_out, blocks_to_swap_in) + self.blocks_to_swap_of_request_id: List[ + Tuple[int, Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]] = \ + [] + def can_allocate(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> AllocStatus: @@ -169,6 +174,11 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: self.block_allocator.will_swap_in_cpu_blocks(): self._computed_blocks_tracker.on_swap_in_cpu_blocks(seq.seq_id) + blocks_to_swap_out, blocks_to_swap_in = \ + self.block_allocator.get_and_reset_swaps() + if (len(blocks_to_swap_out) + len(blocks_to_swap_in) > 0): + self.blocks_to_swap_of_request_id.append(( + seq.seq_id, (blocks_to_swap_out, blocks_to_swap_in))) return block_table def allocate(self, seq_group: SequenceGroup) -> None: @@ -256,6 +266,11 @@ def append_slots( ) # Return any new copy-on-writes. new_cows = self.block_allocator.clear_copy_on_writes() + blocks_to_swap_out, blocks_to_swap_in = \ + self.block_allocator.get_and_reset_swaps() + if (len(blocks_to_swap_out) + len(blocks_to_swap_in) > 0): + self.blocks_to_swap_of_request_id.append(( + seq.seq_id, (blocks_to_swap_out, blocks_to_swap_in))) return new_cows def free(self, seq: Sequence) -> None: @@ -525,7 +540,8 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: return self._computed_blocks_tracker.get_num_cached_tokens(seq) def get_and_reset_swaps(self, - now: float) -> Tuple[List[Tuple[int, int]], ...]: + now: float) -> \ + List[Tuple[int, Tuple[List[Tuple[int, int]], ...]]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. @@ -539,4 +555,7 @@ def get_and_reset_swaps(self, source to destination block IDs. The block IDs are physical block IDs and it's expected to be used by the cache engine directly. """ - return self.block_allocator.get_and_reset_swaps(now) + ret = self.blocks_to_swap_of_request_id + self.block_allocator.access_cpu_hit_blocks(now) + self.blocks_to_swap_of_request_id = [] + return ret diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e88ab77e26d6..ca33dcd0507d0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -138,8 +138,14 @@ class SchedulerOutputs: num_batched_tokens: int # Blocks to swap in. List of CPU -> GPU block number. blocks_to_swap_in: List[Tuple[int, int]] + # swap in requests offsets + swap_in_offsets: List[int] # Blocks to swap out. List of GPU -> CPU block number. blocks_to_swap_out: List[Tuple[int, int]] + # swap out requests offsets + swap_out_offsets: List[int] + # swap requests IDs + swap_sequence_ids: List[int] # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. @@ -1128,13 +1134,20 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_swap_in = swapped_in.blocks_to_swap_in blocks_to_swap_out = running_scheduled.blocks_to_swap_out - # NOTE(Kuntai): extend the swapping list for CPU offloading - new_swap_out, new_swap_in = \ - self.block_manager.get_and_reset_swaps(time.time()) - for src, dst in new_swap_out: - blocks_to_swap_out.extend((src, dst)) - for src, dst in new_swap_in: - blocks_to_swap_in.extend((src, dst)) + swap_out_cnt = len(blocks_to_swap_out) + swap_in_cnt = len(blocks_to_swap_in) + swap_out_offsets = [0, swap_out_cnt] + swap_in_offsets = [0, swap_in_cnt] + swap_sequence_ids = [-1] + for seq_id, (new_swap_out, new_swap_in) in \ + self.block_manager.get_and_reset_swaps(time.time()): + blocks_to_swap_out.extend(new_swap_out) + swap_out_cnt += len(new_swap_out) + swap_out_offsets.append(swap_out_cnt) + blocks_to_swap_in.extend(new_swap_in) + swap_in_cnt += len(new_swap_in) + swap_in_offsets.append(swap_in_cnt) + swap_sequence_ids.append(seq_id) ignored_seq_groups = prefills.ignored_seq_groups ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) @@ -1145,7 +1158,10 @@ def _schedule_default(self) -> SchedulerOutputs: num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, blocks_to_swap_in=blocks_to_swap_in, + swap_in_offsets=swap_in_offsets, blocks_to_swap_out=blocks_to_swap_out, + swap_out_offsets=swap_out_offsets, + swap_sequence_ids=swap_sequence_ids, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, @@ -1221,13 +1237,20 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: blocks_to_swap_in = swapped_in.blocks_to_swap_in blocks_to_swap_out = running_scheduled.blocks_to_swap_out - # NOTE(Kuntai): extend the swapping list for CPU offloading - new_swap_out, new_swap_in = \ - self.block_manager.get_and_reset_swaps(time.time()) - for src, dst in new_swap_out: - blocks_to_swap_out.extend((src, dst)) - for src, dst in new_swap_in: - blocks_to_swap_in.extend((src, dst)) + swap_out_cnt = len(blocks_to_swap_out) + swap_in_cnt = len(blocks_to_swap_in) + swap_out_offsets = [0, swap_out_cnt] + swap_in_offsets = [0, swap_in_cnt] + swap_sequence_ids = [-1] + for seq_id, (new_swap_out, new_swap_in) in \ + self.block_manager.get_and_reset_swaps(time.time()): + blocks_to_swap_out.extend(new_swap_out) + swap_out_cnt += len(new_swap_out) + swap_out_offsets.append(swap_out_cnt) + blocks_to_swap_in.extend(new_swap_in) + swap_in_cnt += len(new_swap_in) + swap_in_offsets.append(swap_in_cnt) + swap_sequence_ids.append(seq_id) # Put prefills first due to Attention backend ordering assumption. scheduled_seq_groups = (prefills.seq_groups + @@ -1252,7 +1275,10 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, blocks_to_swap_in=blocks_to_swap_in, + swap_in_offsets=swap_in_offsets, blocks_to_swap_out=blocks_to_swap_out, + swap_out_offsets=swap_out_offsets, + swap_sequence_ids=swap_sequence_ids, blocks_to_copy=blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 35869d37f9ea0..9c55b78bf89f1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1433,7 +1433,10 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + swap_in_offsets=scheduler_outputs.swap_in_offsets, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + swap_out_offsets=scheduler_outputs.swap_out_offsets, + swap_sequence_ids=scheduler_outputs.swap_sequence_ids, blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, diff --git a/vllm/sequence.py b/vllm/sequence.py index 669124319c4f4..fc936284d71a9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1244,9 +1244,15 @@ class ExecuteModelRequest( # Blocks to swap in. List of CPU -> GPU block number. blocks_to_swap_in: List[Tuple[int, int]] = msgspec.field(default_factory=list) + # swap in requests offsets + swap_in_offsets: List[int] = msgspec.field(default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. blocks_to_swap_out: List[Tuple[int, int]] = msgspec.field(default_factory=list) + # swap out requests offsets + swap_out_offsets: List[int] = msgspec.field(default_factory=list) + # swap requests IDs + swap_sequence_ids: List[int] = msgspec.field(default_factory=list) # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) # Virtual engine ID for pipeline parallel. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ffa3c2af51a2b..b17b880f7947e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -324,9 +324,18 @@ def prepare_worker_input( blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, device="cuda", dtype=torch.int64).view(-1, 2) + swap_in_offsets = torch.tensor(execute_model_req.swap_in_offsets, + device="cuda", + dtype=torch.int64).view(-1) blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, device="cuda", dtype=torch.int64).view(-1, 2) + swap_out_offsets = torch.tensor(execute_model_req.swap_out_offsets, + device="cuda", + dtype=torch.int64).view(-1) + swap_sequence_ids = torch.tensor(execute_model_req.swap_sequence_ids, + device="cuda", + dtype=torch.int64).view(-1) # `blocks_to_copy` is a gpu tensor. The src and tgt of # blocks to copy are in the same device, and `blocks_to_copy` # can be used directly within cuda kernels. @@ -337,7 +346,10 @@ def prepare_worker_input( return WorkerInput( num_seq_groups=num_seq_groups, blocks_to_swap_in=blocks_to_swap_in, + swap_in_offsets=swap_in_offsets, blocks_to_swap_out=blocks_to_swap_out, + swap_out_offsets=swap_out_offsets, + swap_sequence_ids=swap_sequence_ids, blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7c0bc5a678956..1cf0cd9c0df99 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -144,7 +144,10 @@ class WorkerInput: num_seq_groups: Optional[int] = None blocks_to_swap_in: Optional[torch.Tensor] = None + swap_in_offsets: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None + swap_out_offsets: Optional[torch.Tensor] = None + swap_sequence_ids: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None virtual_engine: int = 0 num_steps: int = 1 @@ -161,7 +164,10 @@ def from_broadcasted_tensor_dict( return cls( num_seq_groups=tensor_dict.pop("num_seq_groups"), blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + swap_in_offsets=tensor_dict.pop("swap_in_offsets"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + swap_out_offsets=tensor_dict.pop("swap_out_offsets"), + swap_sequence_ids=tensor_dict.pop("swap_sequence_ids"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], num_steps=tensor_dict.pop("num_steps"), @@ -175,7 +181,10 @@ def as_broadcastable_tensor_dict( tensor_dict = { "num_seq_groups": self.num_seq_groups, "blocks_to_swap_in": self.blocks_to_swap_in, + "swap_in_offsets": self.swap_in_offsets, "blocks_to_swap_out": self.blocks_to_swap_out, + "swap_out_offsets": self.swap_out_offsets, + "swap_sequence_ids": self.swap_sequence_ids, "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, "num_steps": self.num_steps, From 505e60c5c4d648d461227192ff0c14c64898d1c9 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 10 Dec 2024 17:37:00 +0000 Subject: [PATCH 08/17] Updating the benchmark script with correct usage instructions Signed-off-by: ApostaC --- benchmarks/benchmark_long_document_qa.py | 138 ++++++++++++++++++++--- 1 file changed, 120 insertions(+), 18 deletions(-) diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py index 82e37aaccef96..b5ec21bfd0fba 100644 --- a/benchmarks/benchmark_long_document_qa.py +++ b/benchmarks/benchmark_long_document_qa.py @@ -2,35 +2,121 @@ Benchmark the efficiency of prefix caching. This script allows you to benchmark the performance of -a model with and without prefix caching using either fixed prompts -or prompts sampled from the ShareGPT dataset. +a model with prefix-caching or cpu-offloading using fixed prompts Fixed example usage: - python benchmark_prefix_caching.py \ + # This command run the vllm with 50GB CPU memory for offloading + # The workload samples 8 different prompts with a default input + # length of 20010 tokens, then replicates each prompt 2 times. + python benchmark_long_document_qa.py \ --model meta-llama/Llama-2-7b-chat-hf \ --enable-prefix-caching \ - --num-prompts 1 \ - --repeat-count 100 - -ShareGPT example usage: - # This command samples 20 prompts with input lengths - # between 128 and 256 tokens from the ShareGPT dataset, - # then replicates each prompt 5 times. - python benchmark_prefix_caching.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ - --enable-prefix-caching \ - --num-prompts 20 \ - --repeat-count 5 \ - --input-length-range 128:256 + --block-allocator CpuOffloadingBlockAllocator \ + --num-documents 8 \ + --repeat-count 2 \ + --cpu-memory-gb 50 + +Commandline arguments: + + # Basic arguments + --model: The model to use for the benchmark. + + --enable-prefix-caching: Enable prefix caching or not. + + --block-allocator: The block allocator that vLLM uses. + - CpuGpuBlockAllocator: The default block allocator. + - CpuOffloadingBlockAllocator: The block allocator that supports + cpu offloading + + --gpu-memory-utilization: GPU memory utilization for vLLM. + + --cpu-memory-gb: The amount of CPU memory (GB) that is used by vLLM. + NOTE: CPU memory should be larger than GPU KV cache size when + using CpuOffloadingBlockAllocator. + + # Workload-related arguments + --num-documents: The number of documents to sample prompts from. + + --repeat-count: The number of times to repeat each prompt. + + # Other functionality + --seed: Random seed for reproducibility. + + --profile-swap-blocks: Profile the swap_blocks function in the custom ops. """ import random import time +import torch + from vllm import LLM, SamplingParams from vllm.utils import FlexibleArgumentParser +""" +HELPER FUNCTIONS FOR PROFILING +""" +execution_times = {} + + +def build_result_dict(start_time, end_time, *args): + total_time = end_time - start_time + length = -1 + if len(args) > 1 and isinstance(args[1], torch.Tensor): + length = len(args[1]) + + return { + "start_time": start_time, + "total_time": total_time, + "swap_len": length + } + + +def timing_decorator(func): + + def wrapper(*args, **kwargs): + global execution_times + torch.cuda.synchronize() + start_time = time.time() # Record the start time + result = func(*args, **kwargs) # Call the wrapped function + torch.cuda.synchronize() + end_time = time.time() # Record the end time + if func.__name__ not in execution_times: + execution_times[func.__name__] = [] + + res = build_result_dict(start_time, end_time, *args) + execution_times[func.__name__].append(res) + return result # Return the result of the original function + + return wrapper + + +def process_timing_results(): + global execution_times + for key in execution_times: + len_to_time = {} + len_to_count = {} + for item in execution_times[key]: + swap_len = item["swap_len"] + if swap_len not in len_to_time: + len_to_time[swap_len] = 0 + len_to_time[swap_len] += item["total_time"] + + if swap_len not in len_to_count: + len_to_count[swap_len] = 0 + len_to_count[swap_len] += 1 + + for swap_len in len_to_time: + total_time = len_to_time[swap_len] + count = len_to_count[swap_len] + print(f"{key} on {swap_len} pages: " + f"{(count * swap_len) / total_time} pages per second") + + +""" +MAIN FUNCTIONS FOR BENCHMARKING +""" + def test_long_document_qa(llm=None, sampling_params=None, prompts=None): @@ -47,6 +133,10 @@ def repeat_prompts(prompts, repeat_count): def main(args): + if args.profile_swap_blocks: + from vllm.worker.cache_engine import CacheEngine + CacheEngine.swap_out = timing_decorator(CacheEngine.swap_out) + CacheEngine.swap_in = timing_decorator(CacheEngine.swap_in) random.seed(args.seed) @@ -72,6 +162,7 @@ def main(args): block_allocator=args.block_allocator, preemption_mode=preemption_mode, swap_space=args.cpu_memory_gb, + enable_chunked_prefill=False, gpu_memory_utilization=args.gpu_memory_utilization, max_model_len=30000) @@ -86,6 +177,8 @@ def main(args): sampling_params=sampling_params, ) + random.shuffle(prompts) + print("------start generating------") test_long_document_qa( llm=llm, @@ -93,6 +186,9 @@ def main(args): sampling_params=sampling_params, ) + if args.profile_swap_blocks: + process_timing_results() + if __name__ == "__main__": parser = FlexibleArgumentParser( @@ -136,7 +232,7 @@ def main(args): help='Random seed for reproducibility') parser.add_argument('--gpu-memory-utilization', type=float, - default=0.5, + default=0.9, help='GPU memory utilization for vLLM. Should be a ' 'float point number ranging from 0 to 1. For this ' 'test please use a small value so that the GPU ' @@ -160,5 +256,11 @@ def main(args): 'supports offloading the KV cache to CPU . ' 'When using CpuOffloadingBlockAllocator, the ' 'preemption mode must be recompute.') + + parser.add_argument( + '--profile-swap-blocks', + action='store_true', + help='Profile the swap_blocks function in the custom ops') + args = parser.parse_args() main(args) From a517a291491188d1ab1b3dcc36b5ae6eb17abfa3 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 10 Dec 2024 17:59:33 +0000 Subject: [PATCH 09/17] make yapf happy Signed-off-by: ApostaC --- benchmarks/benchmark_long_document_qa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py index b5ec21bfd0fba..2c3dc4b013b1d 100644 --- a/benchmarks/benchmark_long_document_qa.py +++ b/benchmarks/benchmark_long_document_qa.py @@ -52,7 +52,6 @@ from vllm import LLM, SamplingParams from vllm.utils import FlexibleArgumentParser - """ HELPER FUNCTIONS FOR PROFILING """ From 789b00ef14fffa25c0ce81ecfc6ccbb8751070a7 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 10 Dec 2024 18:03:41 +0000 Subject: [PATCH 10/17] fix format checker issues Signed-off-by: ApostaC --- benchmarks/benchmark_long_document_qa.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py index 2c3dc4b013b1d..8d4425b6fb8a5 100644 --- a/benchmarks/benchmark_long_document_qa.py +++ b/benchmarks/benchmark_long_document_qa.py @@ -52,9 +52,7 @@ from vllm import LLM, SamplingParams from vllm.utils import FlexibleArgumentParser -""" -HELPER FUNCTIONS FOR PROFILING -""" + execution_times = {} @@ -112,11 +110,6 @@ def process_timing_results(): f"{(count * swap_len) / total_time} pages per second") -""" -MAIN FUNCTIONS FOR BENCHMARKING -""" - - def test_long_document_qa(llm=None, sampling_params=None, prompts=None): start_time = time.time() From 6d5841f3a2ec7e1028673d3ed10a4df5da12895d Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 11 Dec 2024 12:52:52 +0000 Subject: [PATCH 11/17] Feat: layer-wise transmission Signed-off-by: Dahai Tang --- tests/kernels/test_encoder_decoder_attn.py | 6 +- vllm/attention/layer.py | 47 +++-- .../block/cpu_offloading_block_allocator.py | 1 - vllm/core/block/interfaces.py | 2 +- vllm/core/block_manager.py | 17 +- vllm/core/interfaces.py | 4 +- vllm/core/placeholder_block_space_manager.py | 7 +- vllm/core/scheduler.py | 58 +++--- vllm/engine/llm_engine.py | 3 +- vllm/sequence.py | 6 +- vllm/spec_decode/draft_model_runner.py | 5 +- vllm/v1/worker/gpu_model_runner.py | 3 +- vllm/worker/cache_engine.py | 195 +++++++++++++++++- vllm/worker/cpu_enc_dec_model_runner.py | 3 +- vllm/worker/cpu_model_runner.py | 3 +- vllm/worker/cpu_pooling_model_runner.py | 3 +- vllm/worker/enc_dec_model_runner.py | 3 +- vllm/worker/model_runner.py | 12 +- vllm/worker/pooling_model_runner.py | 3 +- vllm/worker/worker.py | 50 +++-- vllm/worker/worker_base.py | 29 ++- 21 files changed, 355 insertions(+), 105 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index d943b048b7934..c46b02194653a 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -625,7 +625,7 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - with set_forward_context(attn_metadata, vllm_config): + with set_forward_context({"attn_metadata": attn_metadata}, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -680,7 +680,7 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - with set_forward_context(attn_metadata, vllm_config): + with set_forward_context({"attn_metadata": attn_metadata}, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -752,7 +752,7 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - with set_forward_context(attn_metadata, vllm_config): + with set_forward_context({"attn_metadata": attn_metadata}, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e024eef286f05..9a61df000aa22 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -129,15 +129,26 @@ def forward( attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + cache_engine = \ + forward_context.dynamic_forward_context.get("cache_engine") + worker_input= \ + forward_context.dynamic_forward_context.get("worker_input") + if self.use_direct_call: - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + if (cache_engine is not None): + cache_engine.swap_in_sync(worker_input.running_sequence_ids) + ret = self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type) + if (cache_engine is not None): + cache_engine.swap_out(worker_input.blocks_to_swap_out) + return ret elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -150,14 +161,22 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) + if (cache_engine is not None): + cache_engine.swap_in_sync(worker_input.running_sequence_ids) torch.ops.vllm.unified_attention_with_output( query, key, value, output, kv_cache, attn_type, self.layer_name) + if (cache_engine is not None): + cache_engine.swap_out(worker_input.blocks_to_swap_out) return output.view(-1, hidden_size) else: - return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, attn_type, - self.layer_name) + if (cache_engine is not None): + cache_engine.swap_in_sync(worker_input.running_sequence_ids) + ret = torch.ops.vllm.unified_attention(query, key, value, kv_cache, + attn_type, self.layer_name) + if (cache_engine is not None): + cache_engine.swap_out(worker_input.blocks_to_swap_out) + return ret def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -177,7 +196,8 @@ def unified_attention( layer_name: str, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context + attn_metadata = forward_context.dynamic_forward_context.get( + "attn_metadata") self = forward_context.static_forward_context[layer_name] return self.impl.forward(query, key, @@ -219,7 +239,8 @@ def unified_attention_with_output( layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context + attn_metadata = forward_context.dynamic_forward_context.get( + "attn_metadata") self = forward_context.static_forward_context[layer_name] self.impl.forward(query, key, diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py index 039b29653ddaa..5ecec5114222f 100644 --- a/vllm/core/block/cpu_offloading_block_allocator.py +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -384,7 +384,6 @@ def access_cpu_hit_blocks(self, now: float) -> None: # free the block cpu_allocator.free(cpu_block) - def will_swap_in_cpu_blocks(self): """Check if there are CPU blocks that will be swapped in diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 7a530f7462900..61a17f7f11063 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -307,5 +307,5 @@ def get_and_reset_swaps(self) -> Tuple[List[Tuple[int, int]], ...]: @abstractmethod def access_cpu_hit_blocks(self, now: float) -> None: - """Access cache hitted blocks on CPU to update last accessed time.""" + """Access cache hit blocks on CPU to update last accessed time.""" pass diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 30c438b0ba964..9f34131bfb79a 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -114,7 +114,7 @@ def __init__( self.block_allocator) # request_id -> (blocks_to_swap_out, blocks_to_swap_in) - self.blocks_to_swap_of_request_id: List[ + self.blocks_to_swap_of_sequence_id: List[ Tuple[int, Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]] = \ [] @@ -177,8 +177,8 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: blocks_to_swap_out, blocks_to_swap_in = \ self.block_allocator.get_and_reset_swaps() if (len(blocks_to_swap_out) + len(blocks_to_swap_in) > 0): - self.blocks_to_swap_of_request_id.append(( - seq.seq_id, (blocks_to_swap_out, blocks_to_swap_in))) + self.blocks_to_swap_of_sequence_id.append( + (seq.seq_id, (blocks_to_swap_out, blocks_to_swap_in))) return block_table def allocate(self, seq_group: SequenceGroup) -> None: @@ -269,8 +269,8 @@ def append_slots( blocks_to_swap_out, blocks_to_swap_in = \ self.block_allocator.get_and_reset_swaps() if (len(blocks_to_swap_out) + len(blocks_to_swap_in) > 0): - self.blocks_to_swap_of_request_id.append(( - seq.seq_id, (blocks_to_swap_out, blocks_to_swap_in))) + self.blocks_to_swap_of_sequence_id.append( + (seq.seq_id, (blocks_to_swap_out, blocks_to_swap_in))) return new_cows def free(self, seq: Sequence) -> None: @@ -541,7 +541,8 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: def get_and_reset_swaps(self, now: float) -> \ - List[Tuple[int, Tuple[List[Tuple[int, int]], ...]]]: + List[Tuple[int, + Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. @@ -555,7 +556,7 @@ def get_and_reset_swaps(self, source to destination block IDs. The block IDs are physical block IDs and it's expected to be used by the cache engine directly. """ - ret = self.blocks_to_swap_of_request_id + ret = self.blocks_to_swap_of_sequence_id self.block_allocator.access_cpu_hit_blocks(now) - self.blocks_to_swap_of_request_id = [] + self.blocks_to_swap_of_sequence_id = [] return ret diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 948b2b63643a5..fb3308508d1bb 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -128,7 +128,9 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: @abstractmethod def get_and_reset_swaps(self, - now: float) -> Tuple[List[Tuple[int, int]], ...]: + now: float) -> \ + List[Tuple[int, + Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index 4c7ac2f8cfb31..a7166aa5ae04b 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -93,6 +93,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_num_cached_tokens(self, seq: Sequence) -> int: return 0 - def get_and_reset_swaps(self, - now: float) -> Tuple[List[Tuple[int, int]], ...]: - return [], [] + def get_and_reset_swaps(self, now: float) -> \ + List[Tuple[int, + Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]]: + return [(-1, ([], []))] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca33dcd0507d0..4cc3bd9edc569 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,12 +140,10 @@ class SchedulerOutputs: blocks_to_swap_in: List[Tuple[int, int]] # swap in requests offsets swap_in_offsets: List[int] + # swap in sequence IDs + swap_in_sequence_ids: List[int] # Blocks to swap out. List of GPU -> CPU block number. blocks_to_swap_out: List[Tuple[int, int]] - # swap out requests offsets - swap_out_offsets: List[int] - # swap requests IDs - swap_sequence_ids: List[int] # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. @@ -1134,20 +1132,21 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_swap_in = swapped_in.blocks_to_swap_in blocks_to_swap_out = running_scheduled.blocks_to_swap_out - swap_out_cnt = len(blocks_to_swap_out) - swap_in_cnt = len(blocks_to_swap_in) - swap_out_offsets = [0, swap_out_cnt] - swap_in_offsets = [0, swap_in_cnt] - swap_sequence_ids = [-1] + swap_in_cnt = 0 + swap_in_offsets = [0] + swap_in_sequence_ids = [] + if (len(blocks_to_swap_in) > 0): + swap_in_cnt += len(blocks_to_swap_in) + swap_in_offsets.append(swap_in_cnt) + swap_in_sequence_ids.append(-1) for seq_id, (new_swap_out, new_swap_in) in \ self.block_manager.get_and_reset_swaps(time.time()): blocks_to_swap_out.extend(new_swap_out) - swap_out_cnt += len(new_swap_out) - swap_out_offsets.append(swap_out_cnt) - blocks_to_swap_in.extend(new_swap_in) - swap_in_cnt += len(new_swap_in) - swap_in_offsets.append(swap_in_cnt) - swap_sequence_ids.append(seq_id) + if (len(new_swap_in) > 0): + blocks_to_swap_in.extend(new_swap_in) + swap_in_cnt += len(new_swap_in) + swap_in_offsets.append(swap_in_cnt) + swap_in_sequence_ids.append(seq_id) ignored_seq_groups = prefills.ignored_seq_groups ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) @@ -1159,9 +1158,8 @@ def _schedule_default(self) -> SchedulerOutputs: budget.num_cached_tokens, blocks_to_swap_in=blocks_to_swap_in, swap_in_offsets=swap_in_offsets, + swap_in_sequence_ids=swap_in_sequence_ids, blocks_to_swap_out=blocks_to_swap_out, - swap_out_offsets=swap_out_offsets, - swap_sequence_ids=swap_sequence_ids, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, @@ -1237,20 +1235,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: blocks_to_swap_in = swapped_in.blocks_to_swap_in blocks_to_swap_out = running_scheduled.blocks_to_swap_out - swap_out_cnt = len(blocks_to_swap_out) - swap_in_cnt = len(blocks_to_swap_in) - swap_out_offsets = [0, swap_out_cnt] - swap_in_offsets = [0, swap_in_cnt] - swap_sequence_ids = [-1] + swap_in_cnt = 0 + swap_in_offsets = [0] + swap_in_sequence_ids = [] + if (len(blocks_to_swap_in) > 0): + swap_in_cnt += len(blocks_to_swap_in) + swap_in_offsets.append(swap_in_cnt) + swap_in_sequence_ids.append(-1) for seq_id, (new_swap_out, new_swap_in) in \ self.block_manager.get_and_reset_swaps(time.time()): blocks_to_swap_out.extend(new_swap_out) - swap_out_cnt += len(new_swap_out) - swap_out_offsets.append(swap_out_cnt) - blocks_to_swap_in.extend(new_swap_in) - swap_in_cnt += len(new_swap_in) - swap_in_offsets.append(swap_in_cnt) - swap_sequence_ids.append(seq_id) + if (len(new_swap_in) > 0): + blocks_to_swap_in.extend(new_swap_in) + swap_in_cnt += len(new_swap_in) + swap_in_offsets.append(swap_in_cnt) + swap_in_sequence_ids.append(seq_id) # Put prefills first due to Attention backend ordering assumption. scheduled_seq_groups = (prefills.seq_groups + @@ -1276,9 +1275,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: budget.num_cached_tokens, blocks_to_swap_in=blocks_to_swap_in, swap_in_offsets=swap_in_offsets, + swap_in_sequence_ids=swap_in_sequence_ids, blocks_to_swap_out=blocks_to_swap_out, - swap_out_offsets=swap_out_offsets, - swap_sequence_ids=swap_sequence_ids, blocks_to_copy=blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9c55b78bf89f1..aa5b1abf87784 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1434,9 +1434,8 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, swap_in_offsets=scheduler_outputs.swap_in_offsets, + swap_in_sequence_ids=scheduler_outputs.swap_in_sequence_ids, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - swap_out_offsets=scheduler_outputs.swap_out_offsets, - swap_sequence_ids=scheduler_outputs.swap_sequence_ids, blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, diff --git a/vllm/sequence.py b/vllm/sequence.py index fc936284d71a9..8c1a1327991ff 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1246,13 +1246,11 @@ class ExecuteModelRequest( int]] = msgspec.field(default_factory=list) # swap in requests offsets swap_in_offsets: List[int] = msgspec.field(default_factory=list) + # swap in sequence IDs + swap_in_sequence_ids: List[int] = msgspec.field(default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. blocks_to_swap_out: List[Tuple[int, int]] = msgspec.field(default_factory=list) - # swap out requests offsets - swap_out_offsets: List[int] = msgspec.field(default_factory=list) - # swap requests IDs - swap_sequence_ids: List[int] = msgspec.field(default_factory=list) # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) # Virtual engine ID for pipeline parallel. diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index fe5fd39f42ac9..5dd51c57bbb9d 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -273,8 +273,9 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - with set_forward_context(model_input.attn_metadata, - self.vllm_config): + with set_forward_context( + {"attn_metadata": model_input.attn_metadata}, + self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4692762493f00..10fc00930a8ef 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -452,7 +452,8 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": attn_metadata}, + self.vllm_config): hidden_states = self.model( input_ids=None, positions=self.positions[:num_input_tokens], diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ac3270d1c9909..aa81906ff324b 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,11 +1,13 @@ """CacheEngine class for managing the KV cache.""" -from typing import List +from collections import deque +from typing import Any, Dict, List import torch from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) @@ -84,7 +86,10 @@ def _allocate_kv_cache( device=device)) return kv_cache - def swap_in(self, src_to_dst: torch.Tensor) -> None: + def swap_in(self, + src_to_dst: torch.Tensor, + offsets: torch.Tensor = None, + sequence_ids: torch.Tensor = None) -> None: for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) @@ -94,6 +99,12 @@ def swap_out(self, src_to_dst: torch.Tensor) -> None: self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) + def swap_in_sync(self, sequence_ids: torch.Tensor) -> None: + pass + + def swap_out_sync(self) -> None: + pass + def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @@ -117,3 +128,183 @@ def get_cache_block_size( dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype_size = get_dtype_size(dtype) return dtype_size * total + + +class EventPool: + + def __init__(self, reserve_num_events: int, device: torch.device): + self.reserve_num_events = reserve_num_events + self.event_queue: deque[torch.cuda.Event] = deque() + self.device = device + with torch.cuda.device(device): + for i in range(reserve_num_events): + event = torch.cuda.Event() + # create the detail new event + event.record() + event.synchronize() + self.event_queue.append(event) + + def get_event(self) -> torch.cuda.Event: + if (len(self.event_queue) == 0): + with torch.cuda.device(self.device): + event = torch.cuda.Event() + # create the detail new event + event.record() + event.synchronize() + self.event_queue.append(event) + return self.event_queue.popleft() + + def put_event(self, event: torch.cuda.Event): + self.event_queue.append(event) + + def get_events(self, num_events: int) -> list[torch.cuda.Event]: + ret = [] + for i in range(num_events): + ret.append(self.get_event()) + return ret + + def put_events(self, events: list[torch.cuda.Event]): + for event in events: + self.event_queue.append(event) + + +class GPUCacheEngine(CacheEngine): + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ) -> None: + super().__init__(cache_config, model_config, parallel_config, + device_config) + self.use_fast_path = False + # only these *PUs support fast path + if (current_platform.is_cuda()) or \ + (current_platform.is_rocm()): + self.use_fast_path = True + self.swap_in_stream = None + self.swap_in_event_pool = None + self.swap_in_event_map: Dict[int, Any] = {} + self.swap_out_stream = None + self.swap_out_event = None + self.device = None + self._cur_swap_in_sync_layer = 0 + self._cur_swap_out_layer = 0 + if (not self.use_fast_path): + return + # create device streams and events + self.device = torch.device(torch.cuda.current_device()) + with torch.cuda.device(self.device): + self.swap_in_stream = torch.cuda.Stream() + self.swap_in_event_pool = EventPool(64 * self.num_attention_layers, + self.device) + self.swap_out_stream = torch.cuda.Stream() + self.swap_out_event = torch.cuda.Event() + + def swap_in(self, + src_to_dst: torch.Tensor, + offsets: torch.Tensor = None, + sequence_ids: torch.Tensor = None) -> None: + if (not self.use_fast_path) or \ + (sequence_ids is None) or (sequence_ids.numel() == 0): + super().swap_in(src_to_dst) + return + sequence_ids_numpy = sequence_ids.numpy() + for seq_id in sequence_ids_numpy: + # the first one + if (seq_id == -1): + continue + assert (self.swap_in_event_map.get(seq_id) is None) + assert (self.swap_in_event_pool is not None) + tmp_event_list = self.swap_in_event_pool.get_events( + self.num_attention_layers) + self.swap_in_event_map[seq_id] = tmp_event_list + offsets_numpy = offsets.numpy() + forward_stream = torch.cuda.current_stream() + for idx, seq_id in enumerate(sequence_ids_numpy): + start_idx = offsets_numpy[idx] + last_idx = offsets_numpy[idx + 1] + num_blocks = last_idx - start_idx + swap_in_blocks = src_to_dst.narrow(0, start_idx, num_blocks) + for layer_idx in range(self.num_attention_layers): + if (seq_id == -1): + with torch.cuda.stream(forward_stream): + self.attn_backend.swap_blocks( + self.cpu_cache[layer_idx], + self.gpu_cache[layer_idx], swap_in_blocks) + else: + with torch.cuda.stream(self.swap_in_stream): + self.attn_backend.swap_blocks( + self.cpu_cache[layer_idx], + self.gpu_cache[layer_idx], swap_in_blocks) + self.swap_in_event_map[seq_id][layer_idx].record( + self.swap_in_stream) + + def swap_out( + self, + src_to_dst: torch.Tensor, + ) -> None: + if (src_to_dst.numel() == 0): + return + if (not self.use_fast_path): + cur_layer = self._cur_swap_out_layer + self.attn_backend.swap_blocks(self.gpu_cache[cur_layer], + self.cpu_cache[cur_layer], + src_to_dst) + else: + forward_stream = torch.cuda.current_stream() + assert (self.swap_out_event is not None) + self.swap_out_event.record(forward_stream) + self.swap_out_event.wait(self.swap_out_stream) + with torch.cuda.stream(self.swap_out_stream): + cur_layer = self._cur_swap_out_layer + self.attn_backend.swap_blocks(self.gpu_cache[cur_layer], + self.cpu_cache[cur_layer], + src_to_dst) + self._cur_swap_out_layer = \ + (self._cur_swap_out_layer + 1) % self.num_attention_layers + + def _swap_in_layer_sync_with_seq_ids(self, layer_id: int, + seq_ids: torch.Tensor) -> None: + for seq_id in seq_ids.numpy(): + if (self.swap_in_event_map.get(seq_id) is None): + continue + self.swap_in_event_map[seq_id][layer_id].synchronize() + if (layer_id == self.num_attention_layers - 1): + # recycle the events + for seq_id in seq_ids.numpy(): + if (self.swap_in_event_map.get(seq_id) is None): + continue + event_list = self.swap_in_event_map[seq_id] + assert (self.swap_in_event_pool is not None) + self.swap_in_event_pool.put_events(event_list) + del self.swap_in_event_map[seq_id] + + def _swap_in_layer_all_sync(self, layer_id: int) -> None: + for event_list in self.swap_in_event_map.values(): + event_list[layer_id].synchronize() + # recycle the events + if (layer_id == self.num_attention_layers - 1): + for event_list in self.swap_in_event_map.values(): + assert (self.swap_in_event_pool is not None) + self.swap_in_event_pool.put_events(event_list) + self.swap_in_event_map.clear() + + def swap_in_sync(self, sequence_ids: torch.Tensor) -> None: + if (not self.use_fast_path): + return + if (sequence_ids.numel() == 0): + self._swap_in_layer_all_sync(self._cur_swap_in_sync_layer) + else: + self._swap_in_layer_sync_with_seq_ids(self._cur_swap_in_sync_layer, + sequence_ids) + self._cur_swap_in_sync_layer = \ + (self._cur_swap_in_sync_layer + 1) % self.num_attention_layers + + def swap_out_sync(self) -> None: + if (not self.use_fast_path): + return + assert (self.swap_out_stream is not None) + self.swap_out_stream.synchronize() diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index cc24cfe04d2ba..9a19430bfe85d 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -305,7 +305,8 @@ def execute_model( intermediate_tensors, } - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": model_input.attn_metadata}, + self.vllm_config): hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 420aaf8a1b4cd..5bf4f0de6463c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -528,7 +528,8 @@ def execute_model( execute_model_kwargs.update( {"previous_hidden_states": previous_hidden_states}) - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": model_input.attn_metadata}, + self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index 17b2fd2564a04..df43a0f259996 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -69,7 +69,8 @@ def execute_model( intermediate_tensors, } - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": model_input.attn_metadata}, + self.vllm_config): hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 5697fbbaa2041..6511e379b2741 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -175,7 +175,8 @@ def execute_model( } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": model_input.attn_metadata}, + self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4388b3c1ee164..c5a91b58ea5d6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -58,6 +58,8 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend + from vllm.worker.cache_engine import CacheEngine + from vllm.worker.worker_base import WorkerInput logger = init_logger(__name__) @@ -1610,6 +1612,8 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + cache_engine: Optional["CacheEngine"] = None, + worker_input: Optional["WorkerInput"] = None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1674,8 +1678,12 @@ def execute_model( model_forward_start.record() if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config): + with set_forward_context( + { + "attn_metadata": model_input.attn_metadata, + "cache_engine": cache_engine, + "worker_input": worker_input, + }, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 1beae1e3884c5..1950dae5b17b4 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -101,7 +101,8 @@ def execute_model( if model_input.token_types is not None: cross_enc_kwargs["token_type_ids"] = model_input.token_types - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": model_input.attn_metadata}, + self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b17b880f7947e..075ef0e04b9e7 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -22,7 +22,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.worker.cache_engine import CacheEngine +from vllm.worker.cache_engine import CacheEngine, GPUCacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.pooling_model_runner import PoolingModelRunner @@ -289,8 +289,8 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) + GPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.gpu_cache = [ @@ -319,23 +319,32 @@ def prepare_worker_input( virtual_engine = execute_model_req.virtual_engine num_steps = execute_model_req.num_steps num_seq_groups = len(execute_model_req.seq_group_metadata_list) + seq_ids = [] + for metadata_or_delta in execute_model_req.seq_group_metadata_list: + if isinstance(metadata_or_delta, SequenceGroupMetadata): + seq_ids.extend(metadata_or_delta.seq_data.keys()) + if (len(seq_ids) == 0): + swap_in_offsets = [] + swap_in_sequence_ids = [] + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, device="cuda", dtype=torch.int64).view(-1, 2) swap_in_offsets = torch.tensor(execute_model_req.swap_in_offsets, - device="cuda", + device="cpu", dtype=torch.int64).view(-1) + swap_in_sequence_ids = \ + torch.tensor(execute_model_req.swap_in_sequence_ids, + device="cpu", + dtype=torch.int64).view(-1) blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, device="cuda", dtype=torch.int64).view(-1, 2) - swap_out_offsets = torch.tensor(execute_model_req.swap_out_offsets, - device="cuda", - dtype=torch.int64).view(-1) - swap_sequence_ids = torch.tensor(execute_model_req.swap_sequence_ids, - device="cuda", - dtype=torch.int64).view(-1) + running_sequence_ids = torch.tensor(seq_ids, + device="cpu", + dtype=torch.int64).view(-1) # `blocks_to_copy` is a gpu tensor. The src and tgt of # blocks to copy are in the same device, and `blocks_to_copy` # can be used directly within cuda kernels. @@ -347,29 +356,30 @@ def prepare_worker_input( num_seq_groups=num_seq_groups, blocks_to_swap_in=blocks_to_swap_in, swap_in_offsets=swap_in_offsets, + swap_in_sequence_ids=swap_in_sequence_ids, blocks_to_swap_out=blocks_to_swap_out, - swap_out_offsets=swap_out_offsets, - swap_sequence_ids=swap_sequence_ids, blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, + running_sequence_ids=running_sequence_ids, ) + @torch.inference_mode() + def get_cache_engine(self, worker_input: WorkerInput) -> CacheEngine: + return self.cache_engine[worker_input.virtual_engine] + @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: virtual_engine = worker_input.virtual_engine # Issue cache operations. - if (worker_input.blocks_to_swap_in is not None - and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine[virtual_engine].swap_in( - worker_input.blocks_to_swap_in) - if (worker_input.blocks_to_swap_out is not None - and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine[virtual_engine].swap_out( - worker_input.blocks_to_swap_out) if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in, worker_input.swap_in_offsets, + worker_input.swap_in_sequence_ids) def _get_cached_seq_group_metadata( self, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1cf0cd9c0df99..60185e72036f7 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -2,7 +2,8 @@ import os import time from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + Union) import torch @@ -19,6 +20,9 @@ ModelRunnerBase, ModelRunnerInputBase) +if TYPE_CHECKING: + from vllm.worker.cache_engine import CacheEngine + logger = init_logger(__name__) @@ -145,10 +149,10 @@ class WorkerInput: num_seq_groups: Optional[int] = None blocks_to_swap_in: Optional[torch.Tensor] = None swap_in_offsets: Optional[torch.Tensor] = None + swap_in_sequence_ids: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None - swap_out_offsets: Optional[torch.Tensor] = None - swap_sequence_ids: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None + running_sequence_ids: Optional[torch.Tensor] = None virtual_engine: int = 0 num_steps: int = 1 @@ -165,12 +169,12 @@ def from_broadcasted_tensor_dict( num_seq_groups=tensor_dict.pop("num_seq_groups"), blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), swap_in_offsets=tensor_dict.pop("swap_in_offsets"), + swap_in_sequence_ids=tensor_dict.pop("swap_in_sequence_ids"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - swap_out_offsets=tensor_dict.pop("swap_out_offsets"), - swap_sequence_ids=tensor_dict.pop("swap_sequence_ids"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], num_steps=tensor_dict.pop("num_steps"), + running_sequence_ids=tensor_dict.pop("running_sequence_ids"), ) def as_broadcastable_tensor_dict( @@ -182,12 +186,12 @@ def as_broadcastable_tensor_dict( "num_seq_groups": self.num_seq_groups, "blocks_to_swap_in": self.blocks_to_swap_in, "swap_in_offsets": self.swap_in_offsets, + "swap_in_sequence_ids": self.swap_in_sequence_ids, "blocks_to_swap_out": self.blocks_to_swap_out, - "swap_out_offsets": self.swap_out_offsets, - "swap_sequence_ids": self.swap_sequence_ids, "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, "num_steps": self.num_steps, + "running_sequence_ids": self.running_sequence_ids, } return tensor_dict @@ -246,6 +250,13 @@ def execute_worker(self, worker_input: WorkerInput) -> None: """ raise NotImplementedError + def get_cache_engine(self, worker_input: WorkerInput) \ + -> Optional["CacheEngine"]: + """ + Get the cache engine for the worker. + """ + return None + def _get_worker_input_from_broadcast( self ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ @@ -355,8 +366,12 @@ def execute_model( if self.kv_cache is not None else None, intermediate_tensors=intermediate_tensors, num_steps=num_steps, + cache_engine=self.get_cache_engine(worker_input), + worker_input=worker_input, **kwargs, ) + if (self.get_cache_engine(worker_input) is not None): + self.get_cache_engine(worker_input).swap_out_sync() # type: ignore model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: From 2abdf62ee40bddaf1c60f7eb2eefa2bab2c00533 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Wed, 11 Dec 2024 14:53:40 +0000 Subject: [PATCH 12/17] Fix: set_forward_contex for TPU test Signed-off-by: Dahai Tang --- vllm/worker/tpu_model_runner.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 9a054eb8a4cf7..692d1d7282fa5 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -661,10 +662,12 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, p, - model_input.num_samples, - kv_caches) + with set_forward_context({"attn_metadata": attn_metadata}, + self.vllm_config): + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, + p, model_input.num_samples, + kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -709,10 +712,12 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, p, - model_input.num_samples, - kv_caches) + with set_forward_context({"attn_metadata": attn_metadata}, + self.vllm_config): + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, + p, model_input.num_samples, + kv_caches) self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: From 6f97634caa480c94d7588705e9769db910911784 Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 12 Dec 2024 03:27:13 +0000 Subject: [PATCH 13/17] Fix: set forward context while context is None Signed-off-by: Dahai Tang --- vllm/attention/layer.py | 17 +++++++++++++---- vllm/worker/model_runner.py | 3 ++- vllm/worker/tpu_model_runner.py | 6 ++++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5c9002d5bcf87..a2ca6eb97b800 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -131,13 +131,17 @@ def forward( ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() - cache_engine = \ - forward_context.dynamic_forward_context.get("cache_engine") - worker_input= \ - forward_context.dynamic_forward_context.get("worker_input") + cache_engine = None + worker_input = None + if forward_context is not None: + cache_engine = \ + forward_context.dynamic_forward_context.get("cache_engine") + worker_input= \ + forward_context.dynamic_forward_context.get("worker_input") if self.use_direct_call: if (cache_engine is not None): + assert (worker_input is not None) cache_engine.swap_in_sync(worker_input.running_sequence_ids) ret = self.impl.forward(query, key, @@ -148,6 +152,7 @@ def forward( self._v_scale, attn_type=attn_type) if (cache_engine is not None): + assert (worker_input is not None) cache_engine.swap_out(worker_input.blocks_to_swap_out) return ret elif self.use_output: @@ -163,19 +168,23 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) if (cache_engine is not None): + assert (worker_input is not None) cache_engine.swap_in_sync(worker_input.running_sequence_ids) torch.ops.vllm.unified_attention_with_output( query, key, value, output, kv_cache, attn_type, self.layer_name) if (cache_engine is not None): + assert (worker_input is not None) cache_engine.swap_out(worker_input.blocks_to_swap_out) return output.view(-1, hidden_size) else: if (cache_engine is not None): + assert (worker_input is not None) cache_engine.swap_in_sync(worker_input.running_sequence_ids) ret = torch.ops.vllm.unified_attention(query, key, value, kv_cache, attn_type, self.layer_name) if (cache_engine is not None): + assert (worker_input is not None) cache_engine.swap_out(worker_input.blocks_to_swap_out) return ret diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6c232620a9917..87672e70d9edf 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1513,7 +1513,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context({"attn_metadata": attn_metadata}, + self.vllm_config): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 692d1d7282fa5..0a1b264df260a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -264,8 +264,10 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + with set_forward_context({"attn_metadata": attn_metadata}, + self.vllm_config): + self.model(token_ids, position_ids, attn_metadata, input_lens, t, + p, num_samples, kv_caches) def warmup_model( self, From 7a6435d66e5b9dbad6e10d61ad2c0a1d5e29c23a Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 12 Dec 2024 04:21:49 +0000 Subject: [PATCH 14/17] Fix: change model runner arguments to support kwargs Signed-off-by: Dahai Tang --- vllm/worker/cpu_enc_dec_model_runner.py | 1 + vllm/worker/cpu_model_runner.py | 1 + vllm/worker/cpu_pooling_model_runner.py | 1 + vllm/worker/enc_dec_model_runner.py | 1 + vllm/worker/hpu_model_runner.py | 1 + vllm/worker/model_runner.py | 1 + vllm/worker/model_runner_base.py | 1 + vllm/worker/multi_step_model_runner.py | 1 + vllm/worker/neuron_model_runner.py | 1 + vllm/worker/openvino_model_runner.py | 1 + vllm/worker/pooling_model_runner.py | 1 + vllm/worker/tpu_model_runner.py | 1 + vllm/worker/xpu_model_runner.py | 1 + 13 files changed, 13 insertions(+) diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index 9a19430bfe85d..44b5cb0c6356e 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -280,6 +280,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 5bf4f0de6463c..9fe2b6fec711c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -512,6 +512,7 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, previous_hidden_states: Optional[torch.Tensor] = None, + **kwargs, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index df43a0f259996..14c3b7c8319bb 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -34,6 +34,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 6511e379b2741..1ae051b6f38ea 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -154,6 +154,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[List[PoolerOutput]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in " diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 99cf9a7e67256..901699e8c61b0 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1893,6 +1893,7 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, warmup_mode=False, + **kwargs: Any, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 87672e70d9edf..ec6ef65ab23d3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1617,6 +1617,7 @@ def execute_model( num_steps: int = 1, cache_engine: Optional["CacheEngine"] = None, worker_input: Optional["WorkerInput"] = None, + **kwargs: Any, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index cd4770202a186..a086f988bc2f4 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -272,6 +272,7 @@ def execute_model( kv_caches: Optional[List[torch.Tensor]], intermediate_tensors: Optional[IntermediateTensors], num_steps: int = 1, + **kwargs: Any, ) -> Optional[List[SamplerOutput]]: """ Execute the model on the given input. diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index e08a61e31fe42..c2f4892eff787 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -461,6 +461,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: """ Execute the model for a single step and update multi-step diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index ae4eb6ba6eaec..aa424022111c3 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -309,6 +309,7 @@ def execute_model( kv_caches: Optional[List[torch.Tensor]] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6000e5dfe4e30..b1480a624c9db 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -327,6 +327,7 @@ def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]], + **kwargs: Any, ) -> Optional[SamplerOutput]: ( input_tokens, diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 1950dae5b17b4..4ebab6bd9d3c8 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -49,6 +49,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 0a1b264df260a..b010274b401bc 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -587,6 +587,7 @@ def execute_model( kv_caches: Optional[List[Any]], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> List[SamplerOutput]: assert intermediate_tensors is None if not model_input.is_first_multi_step: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index e6322e095bbb9..8c871ffb8189a 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -557,6 +557,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs: Any, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( From 47c3557f0c467464ce86a30a29ba1666090afb9c Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 12 Dec 2024 04:27:22 +0000 Subject: [PATCH 15/17] Fix: lint checker Signed-off-by: Dahai Tang --- vllm/worker/openvino_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index b1480a624c9db..7ffdc7a78e524 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple import openvino as ov import torch From 894ab902b091d85dd9a6969e430adbd802d8e36a Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 12 Dec 2024 06:17:14 +0000 Subject: [PATCH 16/17] Fix: cpu offloading block allocator tester Signed-off-by: Dahai Tang --- .../test_cpu_offloading_block_allocator.py | 27 ++++++++++++------- vllm/attention/layer.py | 14 ++++++---- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/tests/core/block/test_cpu_offloading_block_allocator.py b/tests/core/block/test_cpu_offloading_block_allocator.py index df4dbc40f12e1..d4c8b1e37ff12 100644 --- a/tests/core/block/test_cpu_offloading_block_allocator.py +++ b/tests/core/block/test_cpu_offloading_block_allocator.py @@ -29,7 +29,8 @@ def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.GPU) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(0.0) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(0.0) assert len(blocks_to_swap_out) == 0 assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks @@ -38,7 +39,8 @@ def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(1.0) assert len(blocks_to_swap_out) == 0 assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 @@ -77,21 +79,24 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, assert allocator.get_num_free_blocks(Device.GPU) == 0 assert len(allocator._uncached_blocks) == num_gpu_blocks - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(0.0) - assert len(blocks_to_swap_out) == 0 + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(0.0) + assert len(blocks_to_swap_out) == num_gpu_blocks assert len(blocks_to_swap_in) == 0 - assert len(allocator._uncached_blocks) == num_gpu_blocks + assert len(allocator._uncached_blocks) == 0 allocator.mark_blocks_as_computed([block.block_id for block in gpu_blocks]) - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) - assert len(blocks_to_swap_out) + len(blocks_to_swap_in) == num_gpu_blocks + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(1.0) + assert len(blocks_to_swap_out) + len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 _ = [allocator.free(block) for block in gpu_blocks] assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(1.0) assert len(blocks_to_swap_out) == 0 assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 @@ -114,7 +119,8 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, _ = [allocator.free(block) for block in gpu_blocks] assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(2.0) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(2.0) assert len(blocks_to_swap_out) == 0 assert len(blocks_to_swap_in) == 0 assert len(allocator._uncached_blocks) == 0 @@ -135,5 +141,6 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, for block in gpu_blocks ]) - blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(3.0) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps() + allocator.access_cpu_hit_blocks(3.0) assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a2ca6eb97b800..70910393fdb2d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -133,7 +133,7 @@ def forward( forward_context: ForwardContext = get_forward_context() cache_engine = None worker_input = None - if forward_context is not None: + if forward_context.dynamic_forward_context is not None: cache_engine = \ forward_context.dynamic_forward_context.get("cache_engine") worker_input= \ @@ -268,8 +268,10 @@ def unified_attention( layer_name: str, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context.get( - "attn_metadata") + attn_metadata = None + if forward_context.dynamic_forward_context is not None: + attn_metadata = \ + forward_context.dynamic_forward_context.get("attn_metadata") self = forward_context.static_forward_context[layer_name] return self.impl.forward(query, key, @@ -311,8 +313,10 @@ def unified_attention_with_output( layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context.get( - "attn_metadata") + attn_metadata = None + if forward_context.dynamic_forward_context is not None: + attn_metadata = \ + forward_context.dynamic_forward_context.get("attn_metadata") self = forward_context.static_forward_context[layer_name] self.impl.forward(query, key, From 80c8c4e7b2b71fa6ae1290f09923e0b91c7c74bf Mon Sep 17 00:00:00 2001 From: Dahai Tang Date: Thu, 12 Dec 2024 07:55:44 +0000 Subject: [PATCH 17/17] Fix: get_cache_engine while self.cache_engine is None Signed-off-by: Dahai Tang --- vllm/worker/worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b99f66fc26984..adca0fc301391 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -366,6 +366,8 @@ def prepare_worker_input( @torch.inference_mode() def get_cache_engine(self, worker_input: WorkerInput) -> CacheEngine: + if (self.cache_engine is None): + return None return self.cache_engine[worker_input.virtual_engine] @torch.inference_mode()