diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 1211e6ba5aafc..fc6f829c37b06 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -6,6 +6,7 @@
 
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
+from contextlib import nullcontext
 
 import pytest
 
@@ -156,3 +157,68 @@ def test_models_with_fp8_kv_cache(
         name_0="no_chunked_prefill",
         name_1="chunked_prefill",
     )
+
+
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("chunk_size", [30, 32])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset distributed env properly. Use a value > 1 only when testing locally.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_with_prefix_caching(
+    vllm_runner,
+    max_tokens: int,
+    enforce_eager: bool,
+    chunk_size: int,
+    use_v2_block_manager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Checks that decode outputs match exactly with and without prefix
+    caching when chunked prefill is enabled.
+    """
+    model = "meta-llama/Llama-2-7b-chat-hf"
+    # The common prompt has 142 tokens with the Llama-2 tokenizer.
+    common_prompt = "You are a helpful AI assistant " * 20
+    unique_prompts = [
+        "Question",  # Warmup
+        "Question",  # Fully cached
+        "Another question",  # Partially cached
+    ]
+    full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
+
+    max_num_batched_tokens = max_num_seqs = chunk_size
+    outputs = {}  # type: ignore
+    check_result = True
+    for enable in (True, False):
+        with vllm_runner(
+                model,
+                dtype="half",
+                max_num_batched_tokens=max_num_batched_tokens,
+                enable_chunked_prefill=True,
+                enable_prefix_caching=enable,
+                tensor_parallel_size=tensor_parallel_size,
+                use_v2_block_manager=use_v2_block_manager,
+                enforce_eager=enforce_eager,
+                max_num_seqs=max_num_seqs,
+        ) as vllm_model:
+            # It should fail when prefix caching is enabled and the chunk
+            # size is not a multiple of the block size (16).
+            should_fail = chunk_size % 16 != 0 and enable
+            check_result &= not should_fail
+            outputs[enable] = []
+            # Send the requests one-by-one to ensure the cache is populated.
+            with pytest.raises(ValueError) if should_fail else nullcontext():
+                for prompt in full_prompts:
+                    outputs[enable] += vllm_model.generate_greedy([prompt],
+                                                                  max_tokens)
+
+    # Check results only if we did not expect a failure.
+    if check_result:
+        check_outputs_equal(
+            outputs_0_lst=outputs[False],
+            outputs_1_lst=outputs[True],
+            name_0="w/o prefix caching",
+            name_1="with prefix caching",
+        )
diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py
index cd306b9e4d3cc..2ee9f20824f2f 100644
--- a/tests/core/test_block_manager.py
+++ b/tests/core/test_block_manager.py
@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
 
     # assert all blocks are free now
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
+
+
+def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
+    """When prefix caching and chunked prefill are enabled, the block manager
+    should only mark the blocks covered by the current chunk as computed.
+    """
+
+    block_size = 4
+    num_cpu_blocks = 0
+    num_gpu_blocks = 16
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_gpu_blocks,
+                                        num_cpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)
+
+    # The prompt fills num_gpu_blocks - 1 full blocks plus a partial block.
+    prompt_length = block_size * num_gpu_blocks - 1
+
+    # Allocate (reserve) all blocks.
+    _, seq_group = create_dummy_prompt("0",
+                                       prompt_length,
+                                       block_size=block_size)
+    block_manager.allocate(seq_group)
+    assert seq_group.seqs[0].n_blocks == num_gpu_blocks
+
+    # 1st chunk: compute 2.5 blocks. Should mark 2 blocks as computed.
+    token_chunk_size = int(block_size * 2.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 2
+
+    # Update the actually computed token count.
+    seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
+
+    # 2nd chunk: complete the 3rd block and 4 additional blocks.
+    token_chunk_size = int(block_size * 4.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 7
diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py
index 6d9c2f3ebba4a..2f6ea632a5d9b 100644
--- a/tests/core/test_chunked_prefill_scheduler.py
+++ b/tests/core/test_chunked_prefill_scheduler.py
@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
     assert len(get_sequence_groups(out)) == max_seqs
     assert not running[0].is_prefill()
     assert not running[1].is_prefill()
+
+
+def test_prefix_caching():
+    """Verify allocating only full blocks when prefix caching is enabled."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 50
+    # Verify the second request is chunked. Although the remaining budget is
+    # 64 - 50 = 14, we only allocate full blocks for prefix caching, so only
+    # 4 * (14 // 4) = 12 tokens are allocated.
+    assert seq_group_meta[1].token_chunk_size == 12
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 62
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index 666723313c829..24ab9eb66194d 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -681,14 +681,20 @@ def access_all_blocks_in_seq(
             for block in block_table:
                 block.last_accessed = access_time
 
-    def compute_full_blocks_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // self.block_size - 1
+
+        # When chunked prefill is enabled, the computed full blocks
+        # should be calculated based on the number of computed tokens.
+        max_computed_tokens = (seq.data.get_num_computed_tokens() +
+                               token_chunk_size)
+        computed_full_blocks = max_computed_tokens // self.block_size
+
         block_table = self.block_tables[seq.seq_id]
-        if max_full_block == -1:
+        if computed_full_blocks == 0:
             return
-        for i in reversed(range(max_full_block)):
+        for i in reversed(range(computed_full_blocks)):
             if block_table[i].computed:
                 break
             block_table[i].computed = True
@@ -718,10 +724,11 @@ def get_common_computed_block_ids(
         ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
-                self.compute_full_blocks_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq, token_chunk_size)
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         if device == Device.GPU:
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index 7d2db43cb4602..b06385b062e83 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -290,7 +290,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
             self._last_access_blocks_tracker.update_last_access(
                 seq.seq_id, now)
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         # If prefix caching is enabled, mark immutable blocks as computed
         # right after they have been scheduled (for prefill). This assumes
         # the scheduler is synchronous so blocks are actually computed when
diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py
index f16f66e99e7f8..c47d7d8dfb075 100644
--- a/vllm/core/embedding_model_block_manager.py
+++ b/vllm/core/embedding_model_block_manager.py
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
                                       seq_group: List[Sequence]) -> List[int]:
         return []
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index becd0d2e7f849..96f8dd851b2f4 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
         pass
 
     @abstractmethod
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     @abstractmethod
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index fbc53afa38f67..51fde6e4eb7a3 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1226,7 +1226,8 @@ def schedule(
         # will crash the vLLM instance / will not retry.
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
-                scheduled_seq_group.seq_group)
+                scheduled_seq_group.seq_group,
+                scheduled_seq_group.token_chunk_size)
 
         self._seq_group_metadata_cache[self.next_cache_id].reset()
 
@@ -1457,10 +1458,27 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
         for seq in seqs:
             num_new_tokens += seq.get_num_new_tokens()
         assert num_new_tokens > 0
-        # Chunk if a running request cannot fit in.
-        # If number of seq > 1, it means it is doing beam search in a
-        # decode phase. Do not chunk in that case.
+        # Chunk if a running request cannot fit in the given budget.
+        # If the number of seqs > 1, the group is running beam search
+        # in the decode phase; do not chunk in that case.
         if enable_chunking and len(seqs) == 1:
-            num_new_tokens = min(num_new_tokens,
-                                 budget.remaining_token_budget())
+            remaining_token_budget = budget.remaining_token_budget()
+            if self.cache_config.enable_prefix_caching:
+                # When prefix caching is enabled, we always allocate
+                # a number of new tokens that is divisible by the block
+                # size to avoid partial block matching.
+                block_size = self.cache_config.block_size
+                remainder = budget.token_budget % block_size
+                if remainder != 0:
+                    raise ValueError("When enabling chunked prefill and "
+                                     "prefix caching, max_num_batched_tokens "
+                                     "(chunk size) must be divisible by "
+                                     "block size, but got chunk_size "
+                                     f"({budget.token_budget}) % block_size "
+                                     f"({block_size}) = {remainder}")
+                if remaining_token_budget < num_new_tokens:
+                    num_new_tokens = (remaining_token_budget //
+                                      block_size) * block_size
+            else:
+                num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index f556e4ea117ae..2b287a5d27157 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -501,23 +501,48 @@ def _compute_for_prefix_cache_hit(
                             and self.sliding_window is None
                             and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-        if self.chunked_prefill_enabled and prefix_cache_hit:
-            raise RuntimeError(
-                "chunked prefill cannot be used with prefix caching now.")
-
-        # If prefix cache is hit, advance context length to bypass
-        # hit blocks. Accordingly, input tokens, position and query length
-        # have to be updated.
-        if prefix_cache_hit:
-            assert computed_block_nums is not None
-            context_len = len(computed_block_nums) * self.block_size
+
+        if not prefix_cache_hit:
+            return
+
+        assert computed_block_nums is not None
+        # The number of cache-hit prompt tokens in this sequence. Note
+        # that this may be larger than the sequence length if chunked
+        # prefill is enabled.
+        prefix_cache_len = len(computed_block_nums) * self.block_size
+        # The number of prompt tokens computed so far in this sequence.
+        context_len = inter_data.context_lens[seq_idx]
+        # The total number of prompt tokens in this sequence.
+        # When chunked prefill is enabled, this is the number of tokens
+        # in the computed chunks plus the current chunk.
+        seq_len = inter_data.seq_lens[seq_idx]
+        if prefix_cache_len <= context_len:
+            # We already passed the cache hit region,
+            # so do normal computation.
+            pass
+        elif context_len < prefix_cache_len < seq_len:
+            # Partial hit. Compute the missing part.
+            uncomputed_start = prefix_cache_len - context_len
             inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][context_len:]
+                seq_idx][uncomputed_start:]
             inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][context_len:]
+                seq_idx][uncomputed_start:]
+            context_len = prefix_cache_len
+            inter_data.context_lens[seq_idx] = context_len
             inter_data.query_lens[
                 seq_idx] = inter_data.seq_lens[seq_idx] - context_len
+        elif seq_len <= prefix_cache_len:
+            # Full hit. Only compute the last token to avoid
+            # erroneous behavior. FIXME: Ideally we should directly
+            # mark all tokens as computed in the scheduler and do not
+            # schedule this sequence, so this case should not happen.
+            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                seq_idx][-1:]
+            inter_data.input_positions[seq_idx] = inter_data.input_positions[
+                seq_idx][-1:]
+            inter_data.query_lens[seq_idx] = 1
+            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
 
     def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                     seq_idx: int,
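
For reference, a minimal standalone sketch of the two token-accounting rules this patch introduces: the scheduler rounds a prefill chunk down to a multiple of the block size when prefix caching is enabled, and the v1 block manager marks (computed_tokens + token_chunk_size) // block_size blocks as computed. The helper names below are illustrative only, not vLLM APIs; the numeric examples mirror the new unit tests.

# Illustrative sketch only -- these helpers are not part of vLLM.

def num_new_tokens_with_chunking(num_new_tokens: int,
                                 remaining_token_budget: int,
                                 token_budget: int, block_size: int,
                                 enable_prefix_caching: bool) -> int:
    """Mirrors the chunking rule added to Scheduler._get_num_new_tokens."""
    if not enable_prefix_caching:
        return min(num_new_tokens, remaining_token_budget)
    if token_budget % block_size != 0:
        raise ValueError("chunk size must be divisible by block size")
    if remaining_token_budget < num_new_tokens:
        # Round down to a full-block boundary to avoid partial block matching.
        return (remaining_token_budget // block_size) * block_size
    return num_new_tokens


def num_computed_full_blocks(num_computed_tokens: int, token_chunk_size: int,
                             block_size: int) -> int:
    """Mirrors the accounting in compute_full_blocks_in_seq."""
    return (num_computed_tokens + token_chunk_size) // block_size


# With a 64-token budget and 50 tokens already scheduled, a second 50-token
# prompt is chunked to 12 tokens (3 full blocks of size 4), as in
# test_prefix_caching above.
assert num_new_tokens_with_chunking(50, 14, 64, 4, True) == 12
# A 10-token chunk marks 2 full blocks; a further 18-token chunk brings the
# total to 7, as in the block manager test above.
assert num_computed_full_blocks(0, 10, 4) == 2
assert num_computed_full_blocks(10, 18, 4) == 7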
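
Likewise, a hedged sketch of the three cache-hit cases that _compute_for_prefix_cache_hit now distinguishes, using plain lists in place of the InterDataForSeqGroup fields; slice_for_prefix_hit and its return tuple are invented names for illustration, not a vLLM API.

from typing import List, Tuple

# Illustrative sketch only -- not part of vLLM.
def slice_for_prefix_hit(
    input_tokens: List[int],
    context_len: int,
    seq_len: int,
    prefix_cache_len: int,
) -> Tuple[List[int], int, int]:
    """Return (tokens to compute, new context_len, new query_len)."""
    if prefix_cache_len <= context_len:
        # Cache-hit region already consumed by earlier chunks: no change.
        return input_tokens, context_len, len(input_tokens)
    if context_len < prefix_cache_len < seq_len:
        # Partial hit: skip the cached prefix of the current chunk.
        uncomputed_start = prefix_cache_len - context_len
        return (input_tokens[uncomputed_start:], prefix_cache_len,
                seq_len - prefix_cache_len)
    # Full hit: recompute only the last token so the forward pass still runs.
    return input_tokens[-1:], seq_len - 1, 1


# Example: 8 tokens are cached, and the current chunk covers positions 4..11.
tokens, new_ctx, new_q = slice_for_prefix_hit(list(range(4, 12)),
                                              context_len=4,
                                              seq_len=12,
                                              prefix_cache_len=8)
assert (tokens, new_ctx, new_q) == (list(range(8, 12)), 8, 4)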