Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][Bugfix] Always set enable_chunked_prefill = True for V1 #11061

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class EngineArgs:
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_num_seqs: Optional[int] = None
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
revision: Optional[str] = None
Expand Down Expand Up @@ -205,6 +205,9 @@ def __post_init__(self):
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
# Override max_num_seqs if it's not set by user.
if self.max_num_seqs is None:
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024

# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
Expand Down Expand Up @@ -1225,19 +1228,19 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"

# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
# When no user override, set the default values based on the usage
# context.
# TODO(woosuk): Tune the default values for different hardware.
if self.max_num_batched_tokens is None:
# When no user override, set the default values based on the
# usage context.
if usage_context == UsageContext.LLM_CLASS:
logger.warning("Setting max_num_batched_tokens to 8192 "
"for LLM_CLASS usage context.")
self.max_num_seqs = 1024
self.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
logger.warning("Setting max_num_batched_tokens to 2048 "
"for OPENAI_API_SERVER usage context.")
self.max_num_seqs = 1024
self.max_num_batched_tokens = 2048
logger.warning(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, usage_context.value)

def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
"""
Expand Down
Loading