From d47180513468d4f8f7737b523664cfec28a42716 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 27 Nov 2024 19:13:30 -0500 Subject: [PATCH] Support continue final message (#2733) * feat: support continue_final_message param in chat request * feat: add test for continue final message * fix: bump openapi docs * fix: remove continue_final_message chat request param * fix: remove unneeded launcher args in continue test * fix: bump test output * fix: remove accidentally included guideline from rebase * fix: remove guideline tests * fix: adjust continuation tests expected text * fix: replace expected output for continue test --- .../test_llama_completion_single_prompt.json | 23 ++++++ ...ama_completion_single_prompt_continue.json | 23 ++++++ .../models/test_continue_final_message.py | 76 +++++++++++++++++++ router/src/infer/chat_template.rs | 24 +++++- 4 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json create mode 100644 integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json create mode 100644 integration-tests/models/test_continue_final_message.py diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json new file mode 100644 index 00000000000..0bea487c350 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1", + "role": "assistant" + } + } + ], + "created": 1732541189, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", + "usage": { + "completion_tokens": 30, + "prompt_tokens": 49, + "total_tokens": 79 + } +} diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json new file mode 100644 index 00000000000..100fb3385e4 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds", + "role": "assistant" + } + } + ], + "created": 1732541190, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", + "usage": { + "completion_tokens": 30, + "prompt_tokens": 73, + "total_tokens": 103 + } +} diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py new file mode 100644 index 00000000000..01c86dcd104 --- /dev/null +++ b/integration-tests/models/test_continue_final_message.py @@ -0,0 +1,76 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def llama_continue_final_message_handle(launcher): + with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def llama_continue_final_message(llama_continue_final_message_handle): + await llama_continue_final_message_handle.health(300) + return llama_continue_final_message_handle.client + + +def test_llama_completion_single_prompt( + llama_continue_final_message, response_snapshot +): + response = requests.post( + f"{llama_continue_final_message.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "Which is bigger an elephant or a mouse?"}, + ], + "max_tokens": 30, + "stream": False, + "seed": 1337, + }, + headers=llama_continue_final_message.headers, + stream=False, + ) + response = response.json() + print(response) + assert len(response["choices"]) == 1 + content = response["choices"][0]["message"]["content"] + assert ( + content + == "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1" + ) + assert response == response_snapshot + + +def test_llama_completion_single_prompt_continue( + llama_continue_final_message, response_snapshot +): + response = requests.post( + f"{llama_continue_final_message.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "Which is bigger an elephant or a mouse?"}, + { + "role": "assistant", + "content": "the elephant, but have you heard about", + }, + ], + "max_tokens": 30, + "stream": False, + "seed": 1337, + }, + headers=llama_continue_final_message.headers, + stream=False, + ) + response = response.json() + print(response) + assert len(response["choices"]) == 1 + content = response["choices"][0]["message"]["content"] + assert ( + content + == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds" + ) + assert response == response_snapshot diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index f5f1dbcaddc..2bda71933ab 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -75,8 +75,9 @@ impl ChatTemplate { }; let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template + let final_message = messages.last().cloned(); + let mut rendered_template = self + .template .render(ChatTemplateInputs { messages, bos_token: self.bos_token.as_deref(), @@ -84,7 +85,24 @@ impl ChatTemplate { add_generation_prompt: true, tools, }) - .map_err(InferError::TemplateError) + .map_err(InferError::TemplateError)?; + + // if the last message is from the assistant, continue the generation prompt + rendered_template = match final_message { + Some(msg) if msg.role == "assistant" => { + match rendered_template.rfind(msg.content.as_str()) { + // implementation based on feature in transformers pipeline + // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418 + Some(index) => rendered_template[..index + msg.content.len()] + .trim_end() + .to_string(), + None => rendered_template, + } + } + _ => rendered_template, + }; + + Ok(rendered_template) } }