
Commit

better handling of system message for orpo
winglian committed Mar 18, 2024
1 parent 556beca commit 01e5ece
Showing 4 changed files with 21 additions and 5 deletions.
@@ -29,11 +29,13 @@ def load(
     chatml transforms for datasets with system, input, chosen, rejected
     """
 
-    chat_template = "chatml"
+    chat_template = chat_templates("chatml", system_message=cfg.default_system_message)
     if ds_cfg and "chat_template" in ds_cfg:
         chat_template = ds_cfg["chat_template"]
         try:
-            chat_template = chat_templates(chat_template)
+            chat_template = chat_templates(
+                chat_template, system_message=cfg.default_system_message
+            )
         except ValueError:
             pass
13 changes: 11 additions & 2 deletions src/axolotl/utils/chat_templates.py
@@ -2,9 +2,10 @@
 This module provides functionality for selecting chat templates based on user choices.
 These templates are used for formatting messages in a conversation.
 """
+from typing import Optional
 
 
-def chat_templates(user_choice: str):
+def chat_templates(user_choice: str, system_message: Optional[str] = None):
     """
     Finds the correct chat_template for the tokenizer_config.
@@ -18,6 +19,11 @@ def chat_templates(user_choice: str):
         ValueError: If the user_choice is not found in the templates.
     """
 
+    default_system_message = "You are a helpful assistant."
+    template_system_message: str = (
+        system_message if system_message is not None else default_system_message
+    )
+
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
@@ -26,6 +32,9 @@ def chat_templates(user_choice: str):
     }
 
     if user_choice in templates:
-        return templates[user_choice]
+        template = templates[user_choice]
+        if default_system_message in template:
+            template = template.replace(default_system_message, template_system_message)
+        return template
 
     raise ValueError(f"Template '{user_choice}' not found.")
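
Taken together, the chat_templates() change lets a configured cfg.default_system_message replace the hard-coded default prompt baked into the selected template string. Below is a minimal sketch of that substitution, rendered with plain Jinja2; the stand-in template is only an illustration, since the real chatml template shipped in chat_templates.py is not shown in this diff:

from jinja2 import Template

DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

# Simplified stand-in: like the real templates, it bakes the default
# system prompt directly into the Jinja source.
SIMPLE_CHATML = (
    "<|im_start|>system\n" + DEFAULT_SYSTEM_MESSAGE + "<|im_end|>\n"
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
)


def with_system_message(template: str, system_message: str) -> str:
    # Mirrors the replacement added in this commit: a plain substring swap
    # of the baked-in default prompt for the configured one.
    return template.replace(DEFAULT_SYSTEM_MESSAGE, system_message)


custom = with_system_message(SIMPLE_CHATML, "You are a terse code reviewer.")
print(Template(custom).render(messages=[{"role": "user", "content": "hi"}]))
# <|im_start|>system
# You are a terse code reviewer.<|im_end|>
# <|im_start|>user
# hi<|im_end|>
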
5 changes: 5 additions & 0 deletions src/axolotl/utils/config/__init__.py
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
                         f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
                     )
                     cfg.datasets[idx].conversation = "chatml"
+                if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+                    LOG.info(
+                        f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+                    )
+                    cfg.datasets[idx].chat_template = "chatml"
 
 
 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
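
For datasets typed orpo.chat_template, normalization now mirrors the existing sharegpt handling: a top-level chat_template: chatml is propagated to any dataset entry that leaves chat_template unset. A rough sketch of the effect, calling the normalizer directly with a hypothetical minimal config rather than going through full config validation (the dataset path is only an example):

from axolotl.utils.config import normalize_cfg_datasets
from axolotl.utils.dict import DictDefault

# Hypothetical minimal config: a top-level chatml chat_template plus an ORPO
# dataset that does not set its own chat_template.
cfg = DictDefault(
    {
        "chat_template": "chatml",
        "datasets": [
            {
                "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
                "type": "orpo.chat_template",
            }
        ],
    }
)

normalize_cfg_datasets(cfg)

# The new branch copies the top-level template choice down to the dataset.
assert cfg.datasets[0].chat_template == "chatml"
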
2 changes: 1 addition & 1 deletion tests/test_prompt_tokenizers.py
@@ -20,7 +20,7 @@
     Llama2ChatPrompter,
     LLama2ChatTokenizingStrategy,
 )
-from axolotl.prompt_strategies.orpo.chatml import load
+from axolotl.prompt_strategies.orpo.chat_template import load
 from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
