Support continue final message (#2733)
* feat: support continue_final_message param in chat request

* feat: add test for continue final message

* fix: bump openapi docs

* fix: remove continue_final_message chat request param

* fix: remove unneeded launcher args in continue test

* fix: bump test output

* fix: remove accidentally included guideline from rebase

* fix: remove guideline tests

* fix: adjust continuation tests expected text

* fix: replace expected output for continue test
drbh authored Nov 28, 2024 · 1 parent caff779 · commit d471805
Showing 4 changed files with 143 additions and 3 deletions.
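
The change lets the chat endpoint continue a partially written assistant reply: when the final entry in "messages" has role "assistant", the rendered chat prompt is truncated to end at that content instead of appending a fresh generation prompt, so the model picks up mid-reply. A minimal sketch of a continuation request against the OpenAI-compatible route, mirroring the integration tests below (the server address is illustrative):

import requests

# The trailing assistant message is a partial reply for the model to
# continue, not a completed turn to respond to.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",  # illustrative address
    json={
        "model": "tgi",
        "messages": [
            {"role": "user", "content": "Which is bigger, an elephant or a mouse?"},
            {"role": "assistant", "content": "The elephant, but have you heard about"},
        ],
        "max_tokens": 30,
        "stream": False,
    },
)
print(response.json()["choices"][0]["message"]["content"])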
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1",
        "role": "assistant"
      }
    }
  ],
  "created": 1732541189,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 49,
    "total_tokens": 79
  }
}
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds",
        "role": "assistant"
      }
    }
  ],
  "created": 1732541190,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 73,
    "total_tokens": 103
  }
}
76 changes: 76 additions & 0 deletions integration-tests/models/test_continue_final_message.py
@@ -0,0 +1,76 @@
import pytest
import requests


@pytest.fixture(scope="module")
def llama_continue_final_message_handle(launcher):
    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_continue_final_message(llama_continue_final_message_handle):
    await llama_continue_final_message_handle.health(300)
    return llama_continue_final_message_handle.client


def test_llama_completion_single_prompt(
    llama_continue_final_message, response_snapshot
):
    response = requests.post(
        f"{llama_continue_final_message.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
        headers=llama_continue_final_message.headers,
        stream=False,
    )
    response = response.json()
    print(response)
    assert len(response["choices"]) == 1
    content = response["choices"][0]["message"]["content"]
    assert (
        content
        == "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1"
    )
    assert response == response_snapshot


def test_llama_completion_single_prompt_continue(
    llama_continue_final_message, response_snapshot
):
    response = requests.post(
        f"{llama_continue_final_message.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
                {
                    "role": "assistant",
                    "content": "the elephant, but have you heard about",
                },
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
        headers=llama_continue_final_message.headers,
        stream=False,
    )
    response = response.json()
    print(response)
    assert len(response["choices"]) == 1
    content = response["choices"][0]["message"]["content"]
    assert (
        content
        == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds"
    )
    assert response == response_snapshot
24 changes: 21 additions & 3 deletions router/src/infer/chat_template.rs
@@ -75,16 +75,34 @@ impl ChatTemplate {
         };
 
         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
 
-        self.template
+        let final_message = messages.last().cloned();
+        let mut rendered_template = self
+            .template
             .render(ChatTemplateInputs {
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
                 tools,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError)?;
+
+        // if the last message is from the assistant, continue the generation prompt
+        rendered_template = match final_message {
+            Some(msg) if msg.role == "assistant" => {
+                match rendered_template.rfind(msg.content.as_str()) {
+                    // implementation based on feature in transformers pipeline
+                    // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
+                    Some(index) => rendered_template[..index + msg.content.len()]
+                        .trim_end()
+                        .to_string(),
+                    None => rendered_template,
+                }
+            }
+            _ => rendered_template,
+        };
+
+        Ok(rendered_template)
     }
 }
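
For reference, a rough Python equivalent of the trimming step above (the function name and signature are illustrative; the actual logic lives in ChatTemplate::render):

def continue_final_message(rendered_template: str, role: str, content: str) -> str:
    # If the last message is from the assistant, cut the rendered prompt
    # off right after that content so generation resumes mid-reply.
    if role != "assistant":
        return rendered_template
    # rfind locates the last occurrence of the content, matching the
    # Rust str::rfind above; rstrip mirrors trim_end.
    index = rendered_template.rfind(content)
    if index == -1:
        return rendered_template
    return rendered_template[: index + len(content)].rstrip()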
