From 5689e256baf0c45148a01ad147abf11ad82c9690 Mon Sep 17 00:00:00 2001
From: "Evan Z. Liu"
Date: Wed, 24 Jul 2024 18:51:00 -0700
Subject: [PATCH] [Frontend] Represent tokens with identifiable strings (#6626)

---
 tests/entrypoints/openai/test_completion.py   | 10 ++-
 .../openai/test_return_tokens_as_ids.py       | 83 +++++++++++++++++++
 vllm/entrypoints/openai/api_server.py         |  2 +
 vllm/entrypoints/openai/cli_args.py           |  6 ++
 vllm/entrypoints/openai/serving_chat.py       | 23 +++--
 vllm/entrypoints/openai/serving_completion.py | 19 ++++-
 vllm/entrypoints/openai/serving_engine.py     | 14 ++--
 7 files changed, 138 insertions(+), 19 deletions(-)
 create mode 100644 tests/entrypoints/openai/test_return_tokens_as_ids.py

diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index 0896e337b5d24..fe00640c0021e 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -55,8 +55,9 @@ def zephyr_pa_files():
 
 
 @pytest.fixture(scope="module")
-def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
-    args = [
+def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
+                        zephyr_pa_files):
+    return [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
@@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
         "128",
     ]
 
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+
+@pytest.fixture(scope="module")
+def server(default_server_args):
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
         yield remote_server
 
 
diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py
new file mode 100644
index 0000000000000..abe413978e0e5
--- /dev/null
+++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py
@@ -0,0 +1,83 @@
+# Separate these tests out from test_completion and test_chat, because they
+# require launching a second server with a different flag. Running both servers
+# at the same time on a single node will OOM.
+
+import pytest
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+from .test_completion import default_server_args  # noqa: F401
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
+from .test_completion import zephyr_pa_files  # noqa: F401
+from .test_completion import MODEL_NAME
+
+
+@pytest.fixture(scope="module")
+def server_with_return_tokens_as_token_ids_flag(
+        default_server_args):  # noqa: F811
+    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
+    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_completion_return_tokens_as_token_ids_completion(
+        server_with_return_tokens_as_token_ids_flag):
+    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        # Include Unicode characters to test for dividing a single
+        # character across multiple tokens: πŸŽ‰ is [28705, 31862] for the
+        # Zephyr tokenizer
+        prompt="Say 'Hello, world! πŸŽ‰'",
+        echo=True,
+        temperature=0,
+        max_tokens=10,
+        logprobs=1)
+
+    text = completion.choices[0].text
+    token_strs = completion.choices[0].logprobs.tokens
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # Check that the token representations are consistent between raw tokens
+    # and top_logprobs
+    # Slice off the first one, because there's no scoring associated with BOS
+    top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
+    top_logprob_keys = [
+        next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
+    ]
+    assert token_strs[1:] == top_logprob_keys
+
+    # Check that decoding the tokens gives the expected text
+    tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
+    assert text == tokenizer.decode(tokens, skip_special_tokens=True)
+
+
+@pytest.mark.asyncio
+async def test_chat_return_tokens_as_token_ids_completion(
+        server_with_return_tokens_as_token_ids_flag):
+    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+    response = await client.chat.completions.create(
+        model=MODEL_NAME,
+        # Include Unicode characters to test for dividing a single
+        # character across multiple tokens: πŸŽ‰ is [28705, 31862] for the
+        # Zephyr tokenizer
+        messages=[{
+            "role": "system",
+            "content": "You like to respond in only emojis, like πŸŽ‰"
+        }, {
+            "role": "user",
+            "content": "Please write some emojis: πŸ±πŸΆπŸŽ‰"
+        }],
+        temperature=0,
+        max_tokens=8,
+        logprobs=True)
+
+    text = response.choices[0].message.content
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    token_ids = []
+    for logprob_content in response.choices[0].logprobs.content:
+        token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
+    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index add5c91900b23..0fe4dd245b5e6 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -254,6 +254,7 @@ async def build_server(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         chat_template=args.chat_template,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
     openai_serving_completion = OpenAIServingCompletion(
         engine,
@@ -262,6 +263,7 @@
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 64919c8be8642..a4192937980f7 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -128,6 +128,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "using @app.middleware('http'). "
         "If a class is provided, vLLM will add it to the server "
         "using app.add_middleware(). ")
+    parser.add_argument(
+        "--return-tokens-as-token-ids",
+        action="store_true",
+        help="When --max-logprobs is specified, represents single tokens as "
+        "strings of the form 'token_id:{token_id}' so that tokens that "
+        "are not JSON-encodable can be identified.")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 3899509ef3ff4..012f70e661100 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -50,13 +50,15 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
+        return_tokens_as_token_ids: bool = False,
     ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
-                         request_logger=request_logger)
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)
 
         self.response_role = response_role
 
@@ -522,11 +524,14 @@ def _get_top_logprobs(
             self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
             tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
         return [
-            ChatCompletionLogProb(
-                token=(token := self._get_decoded_token(p[1], p[0],
-                                                        tokenizer)),
-                logprob=max(p[1].logprob, -9999.0),
-                bytes=list(token.encode("utf-8", errors="replace")))
+            ChatCompletionLogProb(token=(token := self._get_decoded_token(
+                p[1],
+                p[0],
+                tokenizer,
+                return_as_token_id=self.return_tokens_as_token_ids)),
+                                  logprob=max(p[1].logprob, -9999.0),
+                                  bytes=list(
+                                      token.encode("utf-8", errors="replace")))
             for i, p in enumerate(logprobs.items())
             if top_logprobs and i < top_logprobs
         ]
@@ -546,6 +551,8 @@ def _create_chat_logprobs(
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
+                if self.return_tokens_as_token_ids:
+                    token = f"token_id:{token_id}"
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
                         token=token,
@@ -553,7 +560,9 @@
             else:
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
-                        token=step_top_logprobs[token_id].decoded_token,
+                        token=self._get_decoded_token(
+                            step_top_logprobs[token_id], token_id, tokenizer,
+                            self.return_tokens_as_token_ids),
                         logprob=max(step_top_logprobs[token_id].logprob,
                                     -9999.0),
                         bytes=list(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 6aef4c9f96150..73e420141813e 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -51,13 +51,15 @@ def __init__(
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
     ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
-                         request_logger=request_logger)
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -430,12 +432,17 @@ def _create_completion_logprobs(
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
+                if self.return_tokens_as_token_ids:
+                    token = f"token_id:{token_id}"
                 out_tokens.append(token)
                 out_token_logprobs.append(None)
                 out_top_logprobs.append(None)
             else:
-                token = self._get_decoded_token(step_top_logprobs[token_id],
-                                                token_id, tokenizer)
+                token = self._get_decoded_token(
+                    step_top_logprobs[token_id],
+                    token_id,
+                    tokenizer,
+                    return_as_token_id=self.return_tokens_as_token_ids)
                 token_logprob = max(step_top_logprobs[token_id].logprob,
                                     -9999.0)
                 out_tokens.append(token)
@@ -448,7 +455,11 @@
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=self.return_tokens_as_token_ids):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8c6bd10b9b4d4..321c9ac2c1d5f 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -68,6 +68,7 @@ def __init__(
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
    ):
         super().__init__()
 
@@ -102,6 +103,7 @@ def __init__(
                         prompt_adapter_num_virtual_tokens=num_virtual_tokens))
 
         self.request_logger = request_logger
+        self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -384,11 +386,13 @@ def _log_inputs(
         )
 
     @staticmethod
-    def _get_decoded_token(
-        logprob: Logprob,
-        token_id: int,
-        tokenizer: AnyTokenizer,
-    ) -> str:
+    def _get_decoded_token(logprob: Logprob,
+                           token_id: int,
+                           tokenizer: AnyTokenizer,
+                           return_as_token_id: bool = False) -> str:
+        if return_as_token_id:
+            return f"token_id:{token_id}"
+
         if logprob.decoded_token is not None:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
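
Example usage (an illustrative client-side sketch, not part of the diff): once a vLLM OpenAI-compatible server is launched with --return-tokens-as-token-ids, the logprobs fields carry "token_id:{token_id}" strings instead of decoded token text. The snippet below mirrors the round-trip check in the new test file; the base URL, API key, and model name are assumed placeholder values for a locally launched server, not values taken from this patch.

# Illustrative sketch (assumption: a vLLM OpenAI-compatible server was started
# with --return-tokens-as-token-ids and serves the model below at
# http://localhost:8000/v1; these values are placeholders, not from the patch).
from openai import OpenAI
from transformers import AutoTokenizer

MODEL = "HuggingFaceH4/zephyr-7b-beta"  # assumed model, matching the tests' Zephyr setup
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model=MODEL,
    prompt="Say 'Hello, world! πŸŽ‰'",
    echo=True,
    temperature=0,
    max_tokens=10,
    logprobs=1,
)

choice = completion.choices[0]
# With the flag set, each logprobs token is the string "token_id:<int>" rather
# than decoded text, so tokens that are not JSON-encodable stay identifiable.
token_ids = [int(tok.removeprefix("token_id:")) for tok in choice.logprobs.tokens]

# Decode locally and compare against the returned text, mirroring the
# round-trip assertion in tests/entrypoints/openai/test_return_tokens_as_ids.py.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
assert tokenizer.decode(token_ids, skip_special_tokens=True) == choice.text

Only the token strings in the logprobs payload change; the generated text itself is unaffected, which is why the local decode is expected to match choices[0].text.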