Skip to content

Commit

Permalink
Pin Chat format to TiktokenTokenizerWrapper (#752)
Browse files Browse the repository at this point in the history
* pin default chat template

* cleanup

* cleanup

* cleanup

* chat template as a property

* learning to write jinja

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
rajammanabrolu and dakinggg authored Nov 21, 2023
1 parent 6dc94a2 commit e191b05
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
27 changes: 27 additions & 0 deletions llmfoundry/tokenizers/tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch
from transformers import PreTrainedTokenizer

# System prompt injected by the chat template when the tokenizer is built
# with use_default_system_prompt=True (see TiktokenTokenizerWrapper below).
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible."""


class TiktokenTokenizerWrapper(PreTrainedTokenizer):
"""A thin wrapper around tiktoken to make it compatible with Hugging Face.
Expand All @@ -23,6 +25,7 @@ def __init__(self,
encoding_name: Optional[str] = None,
add_bos_token: bool = False,
add_eos_token: bool = False,
use_default_system_prompt: bool = False,
unk_token: Optional[str] = '<|endoftext|>',
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
Expand All @@ -39,6 +42,7 @@ def __init__(self,
Either model_name or encoding_name must be set, but not both.
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
use_default_system_prompt (bool, optional): Use the default system prompt or not. Defaults to False.
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
Expand Down Expand Up @@ -87,11 +91,13 @@ def pickle_Encoding(enc: Encoding):

self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt

super().__init__(model_name=model_name,
encoding_name=encoding_name,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
use_default_system_prompt=use_default_system_prompt,
unk_token=unk_token,
eos_token=eos_token,
bos_token=bos_token,
Expand All @@ -107,6 +113,27 @@ def vocab_size(self) -> int:
def is_fast(self) -> bool:
    # This wrapper is a pure-Python ("slow") tokenizer, so always report
    # False to the Hugging Face machinery that branches on fast support.
    return False

@property
def default_chat_template(self):
    """Pinned ChatML-style chat template for user/assistant turns.

    The template is pinned here so that upstream changes to the
    transformers default cannot silently alter formatting. When
    ``use_default_system_prompt`` is set, a system turn carrying
    ``DEFAULT_SYSTEM_PROMPT`` is emitted before the conversation.

    Returns:
        str: A Jinja template string consumed by ``apply_chat_template``.
    """
    # Bake the instance's flag directly into the Jinja conditional; the
    # '{% if false == true %}' branch is simply never taken.
    flag = 'true' if self.use_default_system_prompt else 'false'
    # NOTE(review): the system turn below is not terminated with
    # '<|im_end|>\n' — this matches the current tests, but looks like it
    # diverges from strict ChatML; confirm before changing.
    pieces = [
        "{% set system_message = '' %}",
        '{% if ' + flag + ' == true %}',
        "{{'<|im_start|>system\n' + '" + DEFAULT_SYSTEM_PROMPT + "'}}",
        '{% endif %}',
        '{% for message in messages %}',
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}",
        '{% endfor %}',
    ]
    return ''.join(pieces)

def get_vocab(self) -> Dict[str, int]:
"""Returns vocab as a dict.
Expand Down
57 changes: 57 additions & 0 deletions tests/test_tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,34 @@
(None, 'cl100k_base'),
]

# Conversations (ChatML-style role/content dicts) fed to
# apply_chat_template in test_chat_formatting below.
MULTI_TURN_CHAT_ML = [[{
'content':
'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.',
'role':
'user'
}, {
'content': 'You should go outside and touch grass.',
'role': 'assistant'
}]]

# Expected rendered output for each conversation above, index-aligned
# with MULTI_TURN_CHAT_ML. Exact bytes (including newlines) matter.
MULTI_TURN_CHAT_STRING = [
"""<|im_start|>user
Please summarize the goals in this text:
Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|>
<|im_start|>assistant
You should go outside and touch grass.<|im_end|>
"""
]

# Expected system-turn prefix when use_default_system_prompt=True; must
# stay in sync with llmfoundry.tokenizers.tiktoken.DEFAULT_SYSTEM_PROMPT.
DEFAULT_SYSTEM_PROMPT = """<|im_start|>system\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible."""


def get_tokenizers_for_testing(
model_name: Optional[str],
encoding_name: Optional[str],
tmp_path: pathlib.Path,
use_default_system_prompt: bool = False,
add_bos_token: bool = False,
add_eos_token: bool = False,
additional_special_tokens: Optional[List[str]] = None,
Expand All @@ -60,6 +83,7 @@ def get_tokenizers_for_testing(
encoding_name=encoding_name,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
use_default_system_prompt=use_default_system_prompt,
additional_special_tokens=additional_special_tokens)
if model_name is not None:
original_tokenizer = tiktoken.encoding_for_model(model_name)
Expand Down Expand Up @@ -259,3 +283,36 @@ def test_additional_special_tokens(model_name: Optional[str],

assert encoded_outputs[0] == wrapped_tokenizer.vocab_size
assert len(encoded_outputs) == 2


@pytest.mark.parametrize('model_name,encoding_name',
                         MODEL_ENCODING_NAME_PARAMETRIZATION)
def test_chat_formatting(model_name: Optional[str],
                         encoding_name: Optional[str], tmp_path: pathlib.Path):
    """Chat templating renders the pinned ChatML strings.

    Covers the default configuration (no system prompt) and
    use_default_system_prompt=True, comparing the untokenized template
    output against the expected fixture strings.
    """
    # Both ChatML turn delimiters must be registered as special tokens.
    # (Fixed: was '<im_end>', a typo missing the pipes, so the real
    # end-of-turn token '<|im_end|>' was never added.)
    special_tokens_to_add = ['<|im_start|>', '<|im_end|>']
    # Default behavior to not use default system prompt.
    wrapped_tokenizer, _, _ = get_tokenizers_for_testing(
        model_name,
        encoding_name,
        tmp_path,
        add_bos_token=False,
        add_eos_token=False,
        additional_special_tokens=special_tokens_to_add)
    for i, dict_chats in enumerate(MULTI_TURN_CHAT_ML):
        # tokenize=False returns the raw rendered template string.
        chat_str = wrapped_tokenizer.apply_chat_template(dict_chats,
                                                         tokenize=False)
        assert chat_str == MULTI_TURN_CHAT_STRING[i]

    # Using default system prompt: output gains the system-turn prefix.
    wrapped_tokenizer, _, _ = get_tokenizers_for_testing(
        model_name,
        encoding_name,
        tmp_path,
        use_default_system_prompt=True,
        add_bos_token=False,
        add_eos_token=False,
        additional_special_tokens=special_tokens_to_add)
    for i, dict_chats in enumerate(MULTI_TURN_CHAT_ML):
        chat_str = wrapped_tokenizer.apply_chat_template(dict_chats,
                                                         tokenize=False)
        assert chat_str == DEFAULT_SYSTEM_PROMPT + MULTI_TURN_CHAT_STRING[i]

0 comments on commit e191b05

Please sign in to comment.