[Frontend] Add per-request number of cached token stats (vllm-project…
zifeitong authored and weilong.yu committed Dec 13, 2024
1 parent f5739c2 commit 584aa47
Showing 9 changed files with 89 additions and 23 deletions.
24 changes: 22 additions & 2 deletions tests/prefix_caching/test_prefix_caching.py
@@ -27,6 +27,7 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("block_size", [16])
def test_mixed_requests(
hf_runner,
vllm_runner,
@@ -36,11 +37,12 @@ def test_mixed_requests(
dtype: str,
max_tokens: int,
cached_position: int,
block_size: int,
monkeypatch,
) -> None:
"""
Test the case when some sequences have the prefix cache hit
and the others don't. The cached position determines where
the sequence is at among the batch of prefills.
"""
override_backend_env_variable(monkeypatch, backend)
@@ -53,12 +55,30 @@
model,
dtype=dtype,
enable_prefix_caching=True,
block_size=block_size,
) as vllm_model:
# Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)

# Run all the prompts
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
req_outputs = vllm_model.model.generate(example_prompts, greedy_params)

# Verify number of cached tokens
for i in range(len(req_outputs)):
if i == cached_position:
expected_num_cached_tokens = (
len(req_outputs[i].prompt_token_ids) //
block_size) * block_size
else:
expected_num_cached_tokens = 0
assert req_outputs[
i].num_cached_tokens == expected_num_cached_tokens

vllm_outputs = [
(output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text) for output in req_outputs
]

check_outputs_equal(
outputs_0_lst=hf_outputs,
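For context (not part of the diff): the expectation above reflects that prefix-cache hits are counted in whole blocks, so the cached request reports its prompt length rounded down to a multiple of block_size and every other request reports 0. A minimal sketch of the arithmetic with made-up numbers:

# Made-up numbers: a 37-token prompt with block_size=16 reuses two full
# blocks, so 32 tokens are reported as cached.
prompt_len, block_size = 37, 16
expected_num_cached_tokens = (prompt_len // block_size) * block_size
assert expected_num_cached_tokens == 32
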
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -540,6 +540,7 @@ def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
parser.add_argument(
"--enable-prompt-tokens-details",
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")

return parser

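Illustrative usage (not part of the diff): once the server is started with --enable-prompt-tokens-details, the usage block of a chat completion can carry the cached-token count. A hedged sketch with placeholder host, port, and model:

import requests

# Assumes a locally running OpenAI-compatible server started with
# --enable-prompt-tokens-details and prefix caching enabled.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
# On a prefix-cache hit this may include, e.g.,
# "prompt_tokens_details": {"cached_tokens": 32}
print(resp.json()["usage"])
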
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
data: List[ModelCard] = Field(default_factory=list)


class PromptTokenUsageInfo(OpenAIBaseModel):
cached_tokens: Optional[int] = None


class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None


class RequestResponseMetadata(BaseModel):
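A minimal sketch (not part of the diff) of how the models above serialize; the token counts are placeholders:

from vllm.entrypoints.openai.protocol import PromptTokenUsageInfo, UsageInfo

usage = UsageInfo(
    prompt_tokens=48,
    completion_tokens=5,
    total_tokens=53,
    prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=32),
)
# Roughly: {"prompt_tokens":48,"total_tokens":53,"completion_tokens":5,
#           "prompt_tokens_details":{"cached_tokens":32}}
print(usage.model_dump_json())
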
6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/run_batch.py
@@ -78,6 +78,11 @@ def parse_args():
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--enable-prompt-tokens-details",
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")

return parser.parse_args()

@@ -217,6 +222,7 @@ async def main(args):
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
35 changes: 22 additions & 13 deletions vllm/entrypoints/openai/serving_chat.py
@@ -18,8 +18,8 @@
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
@@ -49,7 +49,8 @@ def __init__(self,
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None):
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@@ -80,6 +81,8 @@ def __init__(self,
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e

self.enable_prompt_tokens_details = enable_prompt_tokens_details

async def create_chat_completion(
self,
request: ChatCompletionRequest,
@@ -252,6 +255,7 @@ async def chat_completion_stream_generator(
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
num_cached_tokens = None

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
@@ -305,6 +309,7 @@
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if first_iteration:
num_cached_tokens = res.num_cached_tokens
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
@@ -530,11 +535,13 @@
# is sent, send the usage
if include_usage:
completion_tokens = sum(previous_num_tokens)
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)

final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -702,11 +709,13 @@ async def chat_completion_full_generator(
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens +
num_generated_tokens)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)

request_metadata.final_usage_info = usage

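For the streaming path, a hedged client-side sketch (not part of the diff; endpoint and model are placeholders): with stream_options.include_usage set, the final data chunk carries the usage, and with the new flag it can include prompt_tokens_details. Note that the "and num_cached_tokens" / "and final_res.num_cached_tokens" checks above mean a zero-token hit leaves the field as None.

import json
import requests

with requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "messages": [{"role": "user", "content": "Hi"}],
            "stream": True,
            "stream_options": {"include_usage": True},
        },
        stream=True) as resp:
    for line in resp.iter_lines():
        # Server-sent events: each payload line starts with "data: ".
        if line.startswith(b"data: ") and line != b"data: [DONE]":
            chunk = json.loads(line[len(b"data: "):])
            if chunk.get("usage"):
                print(chunk["usage"].get("prompt_tokens_details"))
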
19 changes: 13 additions & 6 deletions vllm/outputs.py
@@ -83,10 +83,11 @@ class RequestOutput:
finished: Whether the whole request is finished.
metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output.
encoder_prompt: The encoder prompt string of the request;
None if decoder-only
encoder_prompt_token_ids: The token IDs of the encoder prompt;
None if decoder-only
encoder_prompt: The encoder prompt string of the request.
None if decoder-only.
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
"""

def __init__(
@@ -101,6 +102,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None,
num_cached_tokens: Optional[int] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -112,6 +114,7 @@ def __init__(
self.lora_request = lora_request
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens

@classmethod
def new(
@@ -192,13 +195,16 @@ def from_seq_group(

outputs = []
include_prompt = True
# num_cached_tokens should be the same for all the sequences
num_cached_tokens = None
for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)

output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)
num_cached_tokens = seq.data.get_num_cached_tokens()

output_logprobs = seq.output_logprobs if include_logprobs else None

@@ -272,7 +278,7 @@ def from_seq_group(
init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished, seq_group.metrics,
seq_group.lora_request, encoder_prompt,
encoder_prompt_token_ids)
encoder_prompt_token_ids, num_cached_tokens)

if use_cache:
request_output = seq_group.cached_request_output
@@ -293,7 +299,8 @@ def __repr__(self) -> str:
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"lora_request={self.lora_request})")
f"lora_request={self.lora_request}, "
f"num_cached_tokens={self.num_cached_tokens})")


class EmbeddingRequestOutput:
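A hedged offline sketch (not part of the diff) of reading the new attribute; the model and prompt are placeholders, and prefix caching must be enabled for a non-zero count:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
params = SamplingParams(temperature=0.0, max_tokens=5)
prompt = "San Francisco is a city in the state of California on the west coast."
llm.generate(prompt, params)           # first call populates the prefix cache
out = llm.generate(prompt, params)[0]  # second call can hit the cache
# None/0 when nothing was reused; otherwise typically a multiple of block size.
print(out.num_cached_tokens)
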
14 changes: 12 additions & 2 deletions vllm/sequence.py
@@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
# The number of tokens with prefix cache hit.
_num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

@@ -323,6 +325,14 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int):
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE

def get_num_cached_tokens(self) -> int:
"""Return the number of tokens with prefix cache hit."""
return self._num_cached_tokens

def update_num_cached_tokens(self, num_cached_tokens: int):
"""Update the number of tokens with prefix cache hit."""
self._num_cached_tokens = num_cached_tokens

def reset_state_for_recompute(self) -> None:
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
@@ -379,7 +389,7 @@ def __repr__(self) -> str:

class Sequence:
"""Stores the data, status, and block information of a sequence.
The sequence is constructed from the :data:`DecoderOnlyInputs`
(for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
instance passed in through the :code:`inputs` constructor argument.
@@ -906,7 +916,7 @@ class SequenceGroupMetadata(
multi_modal_data: Multi modal data.
mm_processor_kwargs: Multimodal input processor / mapper overrides.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
cross_block_table: Optional cross-attention block table associated
3 changes: 3 additions & 0 deletions vllm/worker/model_runner.py
@@ -542,6 +542,9 @@ def _compute_for_prefix_cache_hit(
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
seq_group_metadata.seq_data[inter_data.seq_ids[
seq_idx]].update_num_cached_tokens(prefix_cache_len)

# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
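For context (not part of the diff), this hunk is where the per-sequence count originates; a condensed, comment-style sketch of the data flow with hypothetical values:

# Hypothetical: two fully matched prefix blocks of 16 tokens each.
prefix_cache_len = 2 * 16  # -> 32
# model_runner.py (above) records it on the sequence data:
#     seq_data.update_num_cached_tokens(prefix_cache_len)
# outputs.py copies it to RequestOutput.num_cached_tokens, and
# serving_chat.py surfaces it as usage.prompt_tokens_details.cached_tokens
# when --enable-prompt-tokens-details is set.
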
