Add assistant prefill for chat templates and TextGenerationPipeline (huggingface#33198)

* Add assistant prefill to chat templates

* Add assistant prefill to pipeline

* Add assistant prefill to pipeline

* Tweak another test that ended in assistant message

* Update tests that ended in assistant messages

* Update tests that ended in assistant messages

* Replace assistant_prefill with continue_final_message

* Allow passing continue_final_message to pipeline

* Small fixup

* Add continue_final_message as a pipeline kwarg

* Update docstrings

* Move repos to hf-internal-testing!

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <[email protected]>

* Add explanatory comment

* make fixup

* Update chat templating docs to explain continue_final_message

---------

Co-authored-by: Lysandre Debut <[email protected]>
Rocketknight1 and LysandreJik authored Sep 2, 2024
1 parent 2d37085 commit 52a0213
Showing 6 changed files with 199 additions and 23 deletions.
37 changes: 37 additions & 0 deletions docs/source/en/chat_templating.md
@@ -197,6 +197,43 @@ Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
effect that `add_generation_prompt` has will depend on the template being used.

## What does "continue_final_message" do?

When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.

Here's an example:

```python
chat = [
{"role": "user", "content": "Can you format the answer in JSON?"},
{"role": "assistant", "content": '{"name": "'},
]

formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, return_tensors="pt", continue_final_message=True)
model.generate(**formatted_chat)
```

The model will generate text that continues the JSON string, rather than starting a new message. This approach
can be very useful for improving the accuracy of the model's instruction-following when you know how you want
it to start its replies.

Because `add_generation_prompt` adds the tokens that start a new message, and `continue_final_message` removes any
end-of-message tokens from the final message, it does not make sense to use them together. As a result, you'll
get an error if you try!
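
For example, here's a minimal sketch of that failure mode (reusing the `tokenizer` and `chat` from the example above):

```python
# Requesting a generation prompt and a continuation of the final message at the
# same time is contradictory, so apply_chat_template raises a ValueError.
try:
    tokenizer.apply_chat_template(chat, add_generation_prompt=True, continue_final_message=True)
except ValueError as err:
    print(err)
```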

<Tip>

The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
argument when calling the pipeline.

</Tip>
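
For example, a minimal sketch of the pipeline behaviour described in the tip above (the checkpoint name is just a placeholder for any chat model):

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta")  # placeholder checkpoint

chat = [
    {"role": "user", "content": "Can you format the answer in JSON?"},
    {"role": "assistant", "content": '{"name": "'},
]

# The final message has the "assistant" role, so the pipeline defaults to
# continue_final_message=True and extends '{"name": "' rather than starting a new reply.
print(pipe(chat, max_new_tokens=32)[0]["generated_text"])

# Passing the argument explicitly overrides the default and starts a fresh assistant message.
print(pipe(chat, max_new_tokens=32, continue_final_message=False)[0]["generated_text"])
```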

## Can I use chat templates in training?

Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
44 changes: 39 additions & 5 deletions src/transformers/pipelines/text_generation.py
Expand Up @@ -131,6 +131,7 @@ def _sanitize_parameters(
stop_sequence=None,
truncation=None,
max_length=None,
continue_final_message=None,
**generate_kwargs,
):
preprocess_params = {}
@@ -165,6 +166,9 @@
)
preprocess_params["handle_long_generation"] = handle_long_generation

if continue_final_message is not None:
preprocess_params["continue_final_message"] = continue_final_message

preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs

@@ -183,6 +187,8 @@
postprocess_params["return_type"] = return_type
if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
if continue_final_message is not None:
postprocess_params["continue_final_message"] = continue_final_message

if stop_sequence is not None:
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
@@ -226,6 +232,10 @@ def __call__(self, text_inputs, **kwargs):
*return_text* is set to True.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the potential extra spaces in the text output.
continue_final_message (`bool`, *optional*): This indicates that you want the model to continue the
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
By default this is `True` when the final message in the input chat has the `assistant` role and
`False` otherwise, but you can manually override that behaviour by setting this flag.
prefix (`str`, *optional*):
Prefix added to prompt.
handle_long_generation (`str`, *optional*):
@@ -270,6 +280,7 @@ def preprocess(
truncation=None,
padding=None,
max_length=None,
continue_final_message=None,
**generate_kwargs,
):
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
@@ -283,9 +294,14 @@

if isinstance(prompt_text, Chat):
tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
inputs = self.tokenizer.apply_chat_template(
prompt_text.messages,
add_generation_prompt=True,
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_dict=True,
return_tensors=self.framework,
**tokenizer_kwargs,
@@ -356,7 +372,13 @@ def _forward(self, model_inputs, **generate_kwargs):
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
def postprocess(
self,
model_outputs,
return_type=ReturnType.FULL_TEXT,
clean_up_tokenization_spaces=True,
continue_final_message=None,
):
generated_sequence = model_outputs["generated_sequence"][0]
input_ids = model_outputs["input_ids"]
prompt_text = model_outputs["prompt_text"]
@@ -390,9 +412,21 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
if isinstance(prompt_text, str):
all_text = prompt_text + all_text
elif isinstance(prompt_text, Chat):
# Explicit list parsing is necessary for parsing chat datasets
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]

if continue_final_message is None:
# If the user passes a chat ending in an assistant message, we treat it as a prefill by
# default because very few models support multiple separate, consecutive assistant messages
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
if continue_final_message:
# With assistant prefill, concat onto the end of the last message
all_text = list(prompt_text.messages)[:-1] + [
{
"role": prompt_text.messages[-1]["role"],
"content": prompt_text.messages[-1]["content"] + all_text,
}
]
else:
# When we're not starting from a prefill, the output is a new assistant message
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
record = {"generated_text": all_text}
records.append(record)

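A minimal sketch of the concatenation performed by the new `continue_final_message` branch of `postprocess` above, using toy values outside the pipeline:

```python
messages = [
    {"role": "user", "content": "This is a test"},
    {"role": "assistant", "content": "This is"},
]
generated = " a continuation"

# Mirrors the prefill branch: the generated text is appended to the final
# (prefilled) message instead of becoming a new assistant turn.
merged = list(messages)[:-1] + [
    {"role": messages[-1]["role"], "content": messages[-1]["content"] + generated}
]
print(merged)
# [{'role': 'user', 'content': 'This is a test'},
#  {'role': 'assistant', 'content': 'This is a continuation'}]
```
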
22 changes: 20 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -1704,6 +1704,7 @@ def apply_chat_template(
documents: Optional[List[Dict[str, str]]] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
continue_final_message: bool = False,
tokenize: bool = True,
padding: bool = False,
truncation: bool = False,
@@ -1737,10 +1738,16 @@
chat_template (`str`, *optional*):
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
argument, as the model's template will be used by default.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
the start of an assistant message. This is useful when you want to generate a response from the model.
add_generation_prompt (bool, *optional*):
If this is set, a prompt with the token(s) that indicate
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
template for this argument to have any effect.
continue_final_message (bool, *optional*):
If this is set, the chat will be formatted so that the final
message in the chat is open-ended, without any EOS tokens. The model will continue this message
rather than starting a new one. This allows you to "prefill" part of
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
tokenize (`bool`, defaults to `True`):
Whether to tokenize the output. If `False`, the output will be a string.
padding (`bool`, defaults to `False`):
@@ -1803,6 +1810,14 @@
conversations = [conversation]
is_batched = False

if continue_final_message:
if add_generation_prompt:
raise ValueError(
"continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
)
if return_assistant_tokens_mask:
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")

# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
if tools is not None:
tool_schemas = []
@@ -1849,6 +1864,9 @@
add_generation_prompt=add_generation_prompt,
**template_kwargs,
)
if continue_final_message:
final_message = chat[-1]["content"]
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
rendered.append(rendered_chat)

if not is_batched:
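
To illustrate the trimming step above, here's a minimal sketch assuming a ChatML-style template where every message is closed by an `<|im_end|>` token (the rendered string is illustrative, not the output of any particular tokenizer):

```python
rendered_chat = (
    "<|im_start|>user\nCan you format the answer in JSON?<|im_end|>\n"
    '<|im_start|>assistant\n{"name": "<|im_end|>\n'
)
final_message = '{"name": "'

# Same logic as the snippet above: cut the rendered string right after the final
# message's content so the trailing <|im_end|> (and newline) are removed and the
# model will continue the prefilled message instead of starting a new one.
trimmed = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
print(trimmed)
```
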
12 changes: 6 additions & 6 deletions tests/pipelines/test_pipelines_common.py
@@ -877,8 +877,8 @@ def test_custom_code_with_string_tokenizer(self):
# See https://github.com/huggingface/transformers/issues/31669
text_generator = pipeline(
"text-generation",
model="Rocketknight1/fake-custom-model-test",
tokenizer="Rocketknight1/fake-custom-model-test",
model="hf-internal-testing/tiny-random-custom-architecture",
tokenizer="hf-internal-testing/tiny-random-custom-architecture",
trust_remote_code=True,
)

@@ -888,8 +888,8 @@
def test_custom_code_with_string_feature_extractor(self):
speech_recognizer = pipeline(
"automatic-speech-recognition",
model="Rocketknight1/fake-custom-wav2vec2",
feature_extractor="Rocketknight1/fake-custom-wav2vec2",
model="hf-internal-testing/fake-custom-wav2vec2",
feature_extractor="hf-internal-testing/fake-custom-wav2vec2",
trust_remote_code=True,
)

@@ -899,8 +899,8 @@
def test_custom_code_with_string_preprocessor(self):
mask_generator = pipeline(
"mask-generation",
model="Rocketknight1/fake-custom-sam",
processor="Rocketknight1/fake-custom-sam",
model="hf-internal-testing/fake-custom-sam",
processor="hf-internal-testing/fake-custom-sam",
trust_remote_code=True,
)

77 changes: 67 additions & 10 deletions tests/pipelines/test_pipelines_text_generation.py
@@ -148,18 +148,16 @@ def test_small_model_pt(self):
@require_torch
def test_small_chat_model_pt(self):
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [
@@ -179,7 +177,7 @@ def test_small_chat_model_pt(self):
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
}
]

@@ -191,6 +189,68 @@
],
)

@require_torch
def test_small_chat_model_continue_final_message(self):
# Here we check that passing a chat that ends in an assistant message is handled correctly
# by continuing the final message rather than starting a new one
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)

# Assert that we continued the last message and there isn't a sneaky <|im_end|>
self.assertEqual(
outputs,
[
{
"generated_text": [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{
"role": "assistant",
"content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
},
]
}
],
)

@require_torch
def test_small_chat_model_continue_final_message_override(self):
# Here we check that passing continue_final_message=True explicitly forces the pipeline to continue
# the final message in the chat, even though it does not end in an assistant message
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)

# Assert that we continued the last message and there isn't a sneaky <|im_end|>
self.assertEqual(
outputs,
[
{
"generated_text": [
{"role": "system", "content": "This is a system message."},
{
"role": "user",
"content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
},
]
}
],
)

@require_torch
def test_small_chat_model_with_dataset_pt(self):
from torch.utils.data import Dataset
@@ -202,7 +262,6 @@ class MyDataset(Dataset):
[
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
],
]

@@ -213,7 +272,7 @@ def __getitem__(self, i):
return {"text": self.data[i]}

text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)

dataset = MyDataset()
@@ -277,18 +336,16 @@ def test_small_model_tf(self):
@require_tf
def test_small_chat_model_tf(self):
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [
@@ -308,7 +365,7 @@ def test_small_chat_model_tf(self):
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
}
]
