Skip to content

Commit

Permalink
[Frontend] Represent tokens with identifiable strings (vllm-project#6626
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ezliu authored Jul 25, 2024
1 parent 740374d commit 5689e25
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 19 deletions.
10 changes: 7 additions & 3 deletions tests/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def zephyr_pa_files():


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
args = [
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
Expand Down Expand Up @@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
"128",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server


Expand Down
83 changes: 83 additions & 0 deletions tests/entrypoints/openai/test_return_tokens_as_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Separate these tests out from test_completion and test_chat, because they
# require launching a second server with a different flag. Running both servers
# at the same time on a single node will OOM.

import pytest

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME


@pytest.fixture(scope="module")
def server_with_return_tokens_as_token_ids_flag(
default_server_args): # noqa: F811
args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
yield remote_server


@pytest.mark.asyncio
async def test_completion_return_tokens_as_token_ids_completion(
server_with_return_tokens_as_token_ids_flag):
client = server_with_return_tokens_as_token_ids_flag.get_async_client()

completion = await client.completions.create(
model=MODEL_NAME,
# Include Unicode characters to test for dividing a single
# character across multiple tokens: 🎉 is [28705, 31862] for the
# Zephyr tokenizer
prompt="Say 'Hello, world! 🎉'",
echo=True,
temperature=0,
max_tokens=10,
logprobs=1)

text = completion.choices[0].text
token_strs = completion.choices[0].logprobs.tokens
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Check that the token representations are consistent between raw tokens
# and top_logprobs
# Slice off the first one, because there's no scoring associated with BOS
top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
top_logprob_keys = [
next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
]
assert token_strs[1:] == top_logprob_keys

# Check that decoding the tokens gives the expected text
tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
assert text == tokenizer.decode(tokens, skip_special_tokens=True)


@pytest.mark.asyncio
async def test_chat_return_tokens_as_token_ids_completion(
server_with_return_tokens_as_token_ids_flag):
client = server_with_return_tokens_as_token_ids_flag.get_async_client()
response = await client.chat.completions.create(
model=MODEL_NAME,
# Include Unicode characters to test for dividing a single
# character across multiple tokens: 🎉 is [28705, 31862] for the
# Zephyr tokenizer
messages=[{
"role": "system",
"content": "You like to respond in only emojis, like 🎉"
}, {
"role": "user",
"content": "Please write some emojis: 🐱🐶🎉"
}],
temperature=0,
max_tokens=8,
logprobs=True)

text = response.choices[0].message.content
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
token_ids = []
for logprob_content in response.choices[0].logprobs.content:
token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ async def build_server(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
engine,
Expand All @@ -262,6 +263,7 @@ async def build_server(
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as"
"strings of the form 'token_id:{token_id}' so that tokens that"
"are not JSON-encodable can be identified.")

parser = AsyncEngineArgs.add_cli_args(parser)

Expand Down
23 changes: 16 additions & 7 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ def __init__(
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)

self.response_role = response_role

Expand Down Expand Up @@ -522,11 +524,14 @@ def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")))
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
p[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
Expand All @@ -546,14 +551,18 @@ def _create_chat_logprobs(
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token,
token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
Expand Down
19 changes: 15 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ def __init__(
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)

async def create_completion(self, request: CompletionRequest,
raw_request: Request):
Expand Down Expand Up @@ -430,12 +432,17 @@ def _create_completion_logprobs(
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id, tokenizer)
token = self._get_decoded_token(
step_top_logprobs[token_id],
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
Expand All @@ -448,7 +455,11 @@ def _create_completion_logprobs(
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
Expand Down
14 changes: 9 additions & 5 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
prompt_adapter_num_virtual_tokens=num_virtual_tokens))

self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
Expand Down Expand Up @@ -384,11 +386,13 @@ def _log_inputs(
)

@staticmethod
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
def _get_decoded_token(logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
return_as_token_id: bool = False) -> str:
if return_as_token_id:
return f"token_id:{token_id}"

if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)

0 comments on commit 5689e25

Please sign in to comment.