diff --git a/src/axolotl/prompt_strategies/orpo/chatml.py b/src/axolotl/prompt_strategies/orpo/chat_template.py
similarity index 96%
rename from src/axolotl/prompt_strategies/orpo/chatml.py
rename to src/axolotl/prompt_strategies/orpo/chat_template.py
index 91be41cbb2..7455ce1069 100644
--- a/src/axolotl/prompt_strategies/orpo/chatml.py
+++ b/src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -29,11 +29,13 @@ def load(
     chatml transforms for datasets with system, input, chosen, rejected
     """
 
-    chat_template = "chatml"
+    chat_template = chat_templates("chatml", system_message=cfg.default_system_message)
     if ds_cfg and "chat_template" in ds_cfg:
         chat_template = ds_cfg["chat_template"]
         try:
-            chat_template = chat_templates(chat_template)
+            chat_template = chat_templates(
+                chat_template, system_message=cfg.default_system_message
+            )
         except ValueError:
             pass
 
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 89ed680a91..26af42e19d 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -2,9 +2,10 @@
 This module provides functionality for selecting chat templates based on user choices.
 These templates are used for formatting messages in a conversation.
 """
+from typing import Optional
 
 
-def chat_templates(user_choice: str):
+def chat_templates(user_choice: str, system_message: Optional[str] = None):
     """
     Finds the correct chat_template for the tokenizer_config.
 
@@ -18,6 +19,11 @@
         ValueError: If the user_choice is not found in the templates.
     """
 
+    default_system_message = "You are a helpful assistant."
+    template_system_message: str = (
+        system_message if system_message is not None else default_system_message
+    )
+
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
@@ -26,6 +32,9 @@
     }
 
     if user_choice in templates:
-        return templates[user_choice]
+        template = templates[user_choice]
+        if default_system_message in template:
+            template = template.replace(default_system_message, template_system_message)
+        return template
 
     raise ValueError(f"Template '{user_choice}' not found.")
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 9151f288a8..3e743bda9f 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
                         f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
                     )
                     cfg.datasets[idx].conversation = "chatml"
+                if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+                    LOG.info(
+                        f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+                    )
+                    cfg.datasets[idx].chat_template = "chatml"
 
 
 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index c7dd87b24b..4e659006fe 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -20,7 +20,7 @@
     Llama2ChatPrompter,
     LLama2ChatTokenizingStrategy,
 )
-from axolotl.prompt_strategies.orpo.chatml import load
+from axolotl.prompt_strategies.orpo.chat_template import load
 from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,