Skip to content

Commit

Permalink
[V1] Implement vLLM V1 [1/N] (vllm-project#9289)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Oct 22, 2024
1 parent 3ddbe25 commit 6c5af09
Show file tree
Hide file tree
Showing 27 changed files with 3,058 additions and 180 deletions.
8 changes: 8 additions & 0 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
Expand Down Expand Up @@ -110,6 +111,10 @@ def get_attn_backend(
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
if backend == _Backend.FLASH_ATTN_VLLM_V1:
from vllm.v1.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend as FlashAttentionBackendV1)
return FlashAttentionBackendV1
if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
Expand Down Expand Up @@ -215,6 +220,9 @@ def which_attn_to_use(
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

if envs.VLLM_USE_V1:
return _Backend.FLASH_ATTN_VLLM_V1

# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if not current_platform.has_device_capability(80):
Expand Down
27 changes: 17 additions & 10 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cloudpickle
import zmq

from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
Expand All @@ -21,12 +21,17 @@
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext

if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]

Expand Down Expand Up @@ -136,14 +141,16 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,

executor_class = LLMEngine._get_executor_cls(engine_config)

return cls(
ipc_path=ipc_path,
use_async_sockets=engine_config.model_config.use_async_output_proc,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
use_async_sockets = (engine_config.model_config.use_async_output_proc
and not VLLM_USE_V1)

return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)

def start(self):
try:
Expand Down
7 changes: 6 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from tqdm import tqdm

from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
Expand All @@ -31,6 +31,11 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of

if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore

logger = init_logger(__name__)


Expand Down
5 changes: 5 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -450,6 +451,10 @@ def get_default_config_root():
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),

# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
}

# end-env-vars-definition
Expand Down
10 changes: 6 additions & 4 deletions vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
if sampling_metadata is not None:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)

# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
Expand All @@ -69,7 +70,8 @@ def forward(
logits *= self.scale

# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)

return logits

Expand Down
168 changes: 3 additions & 165 deletions vllm/transformers_utils/detokenizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional

from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup)

from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup

Expand Down Expand Up @@ -161,167 +163,3 @@ def decode_sequence_inplace(self, seq: Sequence,
seq.output_text += new_decoded_token_text

return len(new_decoded_token_text)


def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""


def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts: List[str] = []
current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return " ".join(sub_texts)
else:
return "".join(sub_texts)


# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# We do not need to convert the whole prompt to tokens.
# Offset a little more in case we have special tokens.
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
skip_special_tokens=skip_special_tokens)
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None

# If the new token id is out of bounds, return an empty string.
if 0 <= new_token_id < len(tokenizer):
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
else:
new_tokens = [""]
output_tokens = prev_tokens + new_tokens

# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens

# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if tokenizer.is_fast or not tokenizer.get_added_vocab():
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
new_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:])
else:
prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)

if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
return new_tokens, "", prefix_offset, read_offset

new_text = new_text[len(prefix_text):]
return new_tokens, new_text, read_offset, len(output_tokens)
Loading

0 comments on commit 6c5af09

Please sign in to comment.