From 8f15b82381530c9938c5a75825b5a0129afa8cc0 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 14 Oct 2024 18:44:48 -0700
Subject: [PATCH] Use separate test with chunked prefill; skip empty chunks

---
 tests/entrypoints/openai/test_chat.py         |  49 +------
 .../entrypoints/openai/test_chunked_prompt.py | 126 ++++++++++++++++++
 vllm/entrypoints/openai/serving_chat.py       |   6 +
 vllm/entrypoints/openai/serving_completion.py |   7 +-
 4 files changed, 138 insertions(+), 50 deletions(-)
 create mode 100644 tests/entrypoints/openai/test_chunked_prompt.py

diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py
index adad2e6f724d5..3af0032fd2fb0 100644
--- a/tests/entrypoints/openai/test_chat.py
+++ b/tests/entrypoints/openai/test_chat.py
@@ -28,7 +28,7 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
         "--dtype",
         "bfloat16",
         "--max-model-len",
-        "None" if MODEL_NAME == "meta-llama/Llama-3.2-1B-Instruct" else "8192",
+        "8192",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
@@ -457,53 +457,6 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
     assert last_completion_tokens == 10
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    ["meta-llama/Llama-3.2-1B-Instruct"],
-)
-async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
-        client: openai.AsyncOpenAI, model_name: str):
-    # Test stream with long prompt
-    messages = [{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "What is the capital of France?" * 3000
-    }]
-    stream = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-        temperature=0.0,
-        stream=True,
-        stream_options={
-            "include_usage": True,
-            "continuous_usage_stats": True
-        },
-        logprobs=True,
-        top_logprobs=5,
-    )
-
-    tokens_received = 0
-    async for chunk in stream:
-        assert chunk.usage.prompt_tokens >= 0
-        assert chunk.usage.completion_tokens >= 0
-        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
-                                            chunk.usage.completion_tokens)
-
-        if chunk.choices[0].delta.content == "":
-            # when there is no tokens generated
-            assert chunk.usage.completion_tokens == 0
-            assert chunk.choices[0].logprobs is None
-        else:
-            tokens_received += 1
-
-        if chunk.choices[0].finish_reason is not None:
-            assert chunk.usage.completion_tokens == tokens_received
-
-
 # NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
 # (i.e. using the same ordering as in the Completions API tests), the test
 # will fail on the second `guided_decoding_backend` even when I swap their order
diff --git a/tests/entrypoints/openai/test_chunked_prompt.py b/tests/entrypoints/openai/test_chunked_prompt.py
new file mode 100644
index 0000000000000..61d66365130c7
--- /dev/null
+++ b/tests/entrypoints/openai/test_chunked_prompt.py
@@ -0,0 +1,126 @@
+import openai  # use the official client for correctness check
+import pytest
+import pytest_asyncio
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        # lora config below
+        "--max-num-seqs",
+        "128",
+        "--enable-chunked-prefill",
+        "--max-num-batched-tokens",
+        "1000",
+        # large prompts create a lot of output
+        "--disable-log-requests",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_completion_stream_options_and_logprobs_with_long_prompts(
+        client: openai.AsyncOpenAI):
+    # Test stream with long prompt
+    prompt = "What is the capital of France?" * 400
+
+    stream = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options={
+            "include_usage": True,
+            "continuous_usage_stats": True,
+        },
+        logprobs=5,
+    )
+
+    tokens_received = 0
+    finished = False
+    async for chunk in stream:
+        assert chunk.usage.prompt_tokens >= 0
+        assert chunk.usage.completion_tokens >= 0
+        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
+                                            chunk.usage.completion_tokens)
+        if not finished:
+            tokens_received += 1
+            assert chunk.choices[0].text
+
+            if chunk.choices[0].finish_reason is not None:
+                finished = True
+
+        if finished:
+            assert chunk.usage.completion_tokens == tokens_received
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
+        client: openai.AsyncOpenAI):
+    # Test stream with long prompt
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "What is the capital of France?" * 400
+    }]
+    stream = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options={
+            "include_usage": True,
+            "continuous_usage_stats": True,
+        },
+        logprobs=True,
+        top_logprobs=5,
+    )
+
+    tokens_received = 0
+    empty_chunks_received = 0
+    finished = False
+    async for chunk in stream:
+        assert chunk.usage.prompt_tokens >= 0
+        assert chunk.usage.completion_tokens >= 0
+        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
+                                            chunk.usage.completion_tokens)
+
+        if not finished:
+            if chunk.choices[0].delta.content == "":
+                # when there is no tokens generated
+                assert chunk.usage.completion_tokens == 0
+                assert chunk.choices[0].logprobs is None
+                empty_chunks_received += 1
+            else:
+                tokens_received += 1
+
+            if chunk.choices[0].finish_reason is not None:
+                finished = True
+
+        if finished:
+            assert chunk.usage.completion_tokens == tokens_received
+
+    assert empty_chunks_received <= 1
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index acb56e4a886e1..a8b1c94325902 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -435,6 +435,12 @@ async def chat_completion_stream_generator(
                         logprobs = None
 
                     delta_text = output.text
+
+                    if not delta_text and not output.token_ids and \
+                        not previous_num_tokens[i]:
+                        # Chunked prefill case, don't return empty chunks
+                        continue
+
                     delta_message: Optional[DeltaMessage]
 
                     # handle streaming deltas for tools with named tool_choice
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 7aa4587e23c15..1e08cd9712bc0 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -274,8 +274,6 @@ async def completion_stream_generator(
 
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
-                    # TODO(simon): optimize the performance by avoiding full
-                    # text O(n^2) sending.
 
                     assert request.max_tokens is not None
                     if request.echo and request.max_tokens == 0:
@@ -307,6 +305,11 @@ async def completion_stream_generator(
                         delta_token_ids = output.token_ids
                         out_logprobs = output.logprobs
 
+                    if not delta_text and not delta_token_ids \
+                            and not previous_num_tokens[i]:
+                        # Chunked prefill case, don't return empty chunks
+                        continue
+
                     if request.logprobs is not None:
                         assert out_logprobs is not None, (
                             "Did not output logprobs")
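
Note: the sketch below is illustrative and not part of the patch. It restates the skip-empty-chunks guard that both stream generators now apply; the helper name should_skip_chunk is hypothetical, while delta_text, token_ids, and previous_num_tokens mirror the names used in the diff above.

    # Minimal sketch of the skip condition, assuming `delta_text` is the text
    # generated in this step, `token_ids` the token IDs generated in this
    # step, and `prev_num_tokens` the number of tokens already streamed for
    # this choice.
    from typing import Sequence


    def should_skip_chunk(delta_text: str, token_ids: Sequence[int],
                          prev_num_tokens: int) -> bool:
        # A chunked-prefill step yields an output with no new text and no new
        # token IDs before anything has been streamed; skipping it mirrors the
        # `continue` added in serving_chat.py and serving_completion.py.
        return not delta_text and not token_ids and not prev_num_tokens


    # Prefill-only step: nothing generated yet, so no chunk is emitted.
    assert should_skip_chunk("", [], 0)
    # First real token: the chunk is emitted.
    assert not should_skip_chunk("Paris", [42], 0)
    # Finish-reason chunk arriving after tokens were streamed: still emitted.
    assert not should_skip_chunk("", [], 2)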