From feafd61542b4c4bb257da1c8820fa7a5924f75d9 Mon Sep 17 00:00:00 2001 From: Elfie Guo Date: Thu, 12 Dec 2024 00:14:46 +0000 Subject: [PATCH] multi-step + chunked-prefill + flashinfer --- tests/multi_step/test_correctness_llm.py | 19 +++++++++++++++-- vllm/attention/backends/flashinfer.py | 26 +++++++++++++++++++----- vllm/worker/model_runner.py | 6 +++++- vllm/worker/multi_step_model_runner.py | 2 +- 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index cc1fd19252019..6fe5e6f76653b 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -5,6 +5,8 @@ import pytest +from tests.kernels.utils import override_backend_env_variable + from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ @@ -19,10 +21,11 @@ @pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_multi_step_llm( hf_runner, vllm_runner, @@ -36,6 +39,8 @@ def test_multi_step_llm( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling via sync LLM Engine. @@ -63,6 +68,7 @@ def test_multi_step_llm( num_logprobs: corresponds to the `logprobs` argument to the OpenAI completions endpoint; `None` -> 1 logprob returned. """ + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -110,10 +116,11 @@ def test_multi_step_llm( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_multi_step_llm_w_prompt_logprobs( vllm_runner, example_prompts, @@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs( num_prompts: int, num_logprobs: Optional[int], num_prompt_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test prompt logprobs with multi-step scheduling via sync LLM Engine. @@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs( note that this argument is not supported by the OpenAI completions endpoint. 
""" + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs( @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"]) def test_multi_step_llm_chunked_prefill_prefix_cache( vllm_runner, example_prompts, @@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. @@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( # # The Incorrect scheduling behavior - if it occurs - will cause an exception # in the model runner resulting from `do_sample=False`. + override_backend_env_variable(monkeypatch, attention_backend) + assert len(example_prompts) >= 2 challenge_prompts = copy.deepcopy(example_prompts) challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..702401b135de4 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -256,7 +256,9 @@ def prepare_graph_input_buffers(self, def begin_forward(self, model_input): assert not self._is_graph_capturing state = self - if model_input.attn_metadata.use_cuda_graph: + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + if use_cuda_graph and is_decode: batch_size = model_input.input_tokens.shape[0] state = (self.runner.graph_runners[model_input.virtual_engine] [batch_size].attn_state) @@ -429,10 +431,24 @@ def advance_step(self, Update metadata in-place to advance one decode step. """ - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with flashinfer yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None assert num_seqs > 0 assert num_queries > 0 diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 26fd486130ce6..ed5db4a8c7fa0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -789,7 +789,7 @@ def _get_cuda_graph_pad_size(self, is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ self.runner.scheduler_config.chunked_prefill_enabled decode_only = self.decode_only or is_mscp - if not decode_only: + if not decode_only or self.runner.is_profile_run: # Early exit so we can treat num_seqs as the batch_size below. 
return -1 @@ -1025,6 +1025,8 @@ def __init__( self.has_inner_state = model_config.has_inner_state + self.is_profile_run = False + # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table @@ -1226,6 +1228,7 @@ def _prepare_model_input_tensors( @torch.inference_mode() def profile_run(self) -> None: + self.is_profile_run = True # Enable top-k sampling to reflect the accurate memory usage. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens @@ -1327,6 +1330,7 @@ def profile_run(self) -> None: self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() + self.is_profile_run = False return def remove_all_loras(self): diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index e08a61e31fe42..0b4a32c7f2ad6 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -32,7 +32,7 @@ MULTI_STEP_ATTENTION_BACKENDS = [ "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" ] -MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"] def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ -> List[str]:
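
The one-line change in vllm/worker/multi_step_model_runner.py is what actually unlocks this combination: adding "FLASHINFER" to MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS lets the flashinfer backend pass the multi-step + chunked-prefill compatibility check, while the new turn_prefills_into_decodes branch in flashinfer.py's advance_step() folds the scheduled prefills into decodes on the first step, mirroring the existing FLASH_ATTN behavior. The body of _get_supported_attention_backends is not part of this diff, so the sketch below is illustrative only; it assumes the function simply chooses between the two module-level lists based on chunked_prefill_enabled.

from typing import List

MULTI_STEP_ATTENTION_BACKENDS = [
    "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]


def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
        -> List[str]:
    # Assumed behavior (body not shown in this diff): when chunked prefill is
    # enabled, only backends whose advance_step() implements
    # turn_prefills_into_decodes are allowed; otherwise any
    # multi-step-capable backend may be used.
    if chunked_prefill_enabled:
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS


# Illustrative check: after this patch, FLASHINFER is accepted for
# multi-step runs with chunked prefill enabled.
assert "FLASHINFER" in _get_supported_attention_backends(True)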