forked from vllm-project/vllm
[Frontend] Represent tokens with identifiable strings (vllm-project#6626)
Showing 7 changed files with 138 additions and 19 deletions.
@@ -0,0 +1,83 @@
# Separate these tests out from test_completion and test_chat, because they
# require launching a second server with a different flag. Running both servers
# at the same time on a single node will OOM.

import pytest

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args  # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
from .test_completion import zephyr_lora_files  # noqa: F401
from .test_completion import zephyr_pa_files  # noqa: F401
from .test_completion import MODEL_NAME


@pytest.fixture(scope="module")
def server_with_return_tokens_as_token_ids_flag(
        default_server_args):  # noqa: F811
    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
        yield remote_server


@pytest.mark.asyncio
async def test_completion_return_tokens_as_token_ids_completion(
        server_with_return_tokens_as_token_ids_flag):
    client = server_with_return_tokens_as_token_ids_flag.get_async_client()

    completion = await client.completions.create(
        model=MODEL_NAME,
        # Include Unicode characters to test for dividing a single
        # character across multiple tokens: 🎉 is [28705, 31862] for the
        # Zephyr tokenizer
        prompt="Say 'Hello, world! 🎉'",
        echo=True,
        temperature=0,
        max_tokens=10,
        logprobs=1)

    text = completion.choices[0].text
    token_strs = completion.choices[0].logprobs.tokens
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # Check that the token representations are consistent between raw tokens
    # and top_logprobs
    # Slice off the first one, because there's no scoring associated with BOS
    top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
    top_logprob_keys = [
        next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
    ]
    assert token_strs[1:] == top_logprob_keys

    # Check that decoding the tokens gives the expected text
    tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
    assert text == tokenizer.decode(tokens, skip_special_tokens=True)


@pytest.mark.asyncio
async def test_chat_return_tokens_as_token_ids_completion(
        server_with_return_tokens_as_token_ids_flag):
    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        # Include Unicode characters to test for dividing a single
        # character across multiple tokens: 🎉 is [28705, 31862] for the
        # Zephyr tokenizer
        messages=[{
            "role": "system",
            "content": "You like to respond in only emojis, like 🎉"
        }, {
            "role": "user",
            "content": "Please write some emojis: 🐱🐶🎉"
        }],
        temperature=0,
        max_tokens=8,
        logprobs=True)

    text = response.choices[0].message.content
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    token_ids = []
    for logprob_content in response.choices[0].logprobs.content:
        token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
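
For context, here is a minimal client-side sketch (not part of the commit) of how the token representation added by this change could be consumed. The base URL, API key, model name, and prompt below are placeholders for illustration; the field accesses (logprobs.tokens and the "token_id:" prefix) mirror the test above, which assumes a server already running with --return-tokens-as-token-ids.

# Minimal client-side sketch, not part of this commit. Assumes a vLLM
# OpenAI-compatible server is already running locally with
# --return-tokens-as-token-ids; the URL, api_key, model name, and prompt
# are placeholders for illustration.
import asyncio

from openai import AsyncOpenAI


async def main():
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = await client.completions.create(
        model="my-model",  # placeholder model name
        prompt="Hello, world!",
        max_tokens=5,
        logprobs=1)
    # With the flag enabled, each token comes back as a "token_id:<id>"
    # string rather than decoded text, so the exact ids can be recovered
    # even when a character is split across multiple tokens.
    token_strs = completion.choices[0].logprobs.tokens
    token_ids = [int(t.removeprefix("token_id:")) for t in token_strs]
    print(token_ids)


asyncio.run(main())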