Skip to content

Commit

Permalink
Adds chat templates (#1022)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhenrichsen authored Dec 29, 2023
1 parent 11fea55 commit 783eef7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,9 @@ datasets:
# For `completion` datasets only, uses the provided field instead of `text` column
field:

# Saves the desired chat template to the tokenizer_config.json for easier inference
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
Expand Down
29 changes: 29 additions & 0 deletions src/axolotl/utils/chat_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""


def chat_templates(user_choice: str):
    """
    Look up the Jinja chat template string matching the user's selection.

    Args:
        user_choice (str): Name of the desired template ("inst" or "chatml").

    Returns:
        str: The Jinja template string to store on the tokenizer config.

    Raises:
        ValueError: If no template is registered under ``user_choice``.
    """

    templates = {
        # Alternating user/assistant format; used by Mistral/Mixtral (exact name unknown).
        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        # ChatML format with <|im_start|>/<|im_end|> role markers.
        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    }

    selected = templates.get(user_choice)
    if selected is None:
        raise ValueError(f"Template '{user_choice}' not found.")
    return selected
7 changes: 7 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl")
Expand Down Expand Up @@ -186,6 +187,12 @@ def load_tokenizer(cfg):
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

if cfg.chat_template:
tokenizer.chat_template = chat_templates(cfg.chat_template)
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return tokenizer


Expand Down

0 comments on commit 783eef7

Please sign in to comment.