revert system prompt changes for chat templates
winglian committed Mar 18, 2024
1 parent 01e5ece commit 0c7ac28
Showing 3 changed files with 20 additions and 16 deletions.
15 changes: 15 additions & 0 deletions docs/rlhf.md
@@ -34,6 +34,21 @@ datasets:
 rl: ipo
 ```
+#### ORPO
+Paper: https://arxiv.org/abs/2403.07691
+```yaml
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: orpo.chat_template
+```
 #### Using local dataset files
 ```yaml
 datasets:
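For intuition about what the added `rl: orpo` / `orpo_alpha` options switch on: the ORPO paper linked above adds an odds-ratio preference term, weighted by a coefficient that `orpo_alpha` appears to correspond to, on top of the ordinary SFT loss over the chosen responses. A rough, hypothetical sketch of that objective (illustrative only, not Axolotl's actual trainer code; tensor names are made up):

```python
import torch
import torch.nn.functional as F

def orpo_loss(nll_chosen, logp_chosen, logp_rejected, orpo_alpha=0.1):
    """Sketch of the ORPO objective from arXiv:2403.07691 (illustrative only).

    nll_chosen:    per-example SFT negative log-likelihood of the chosen response
    logp_chosen:   mean token log-probability of the chosen response
    logp_rejected: mean token log-probability of the rejected response
    """
    # odds(y|x) = p / (1 - p); compute log-odds directly in log space for stability
    log_odds_chosen = logp_chosen - torch.log1p(-torch.exp(logp_chosen))
    log_odds_rejected = logp_rejected - torch.log1p(-torch.exp(logp_rejected))
    # odds-ratio term: -log sigmoid(log odds(chosen) - log odds(rejected))
    odds_ratio_loss = -F.logsigmoid(log_odds_chosen - log_odds_rejected)
    # total loss = SFT loss + alpha * odds-ratio penalty, averaged over the batch
    return (nll_chosen + orpo_alpha * odds_ratio_loss).mean()
```

Because the loss uses only the policy model's own probabilities, ORPO needs no separate reference model, unlike DPO or IPO.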
6 changes: 2 additions & 4 deletions src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -29,13 +29,11 @@ def load(
     chatml transforms for datasets with system, input, chosen, rejected
     """

-    chat_template = chat_templates("chatml", system_message=cfg.default_system_message)
+    chat_template = chat_templates("chatml")
     if ds_cfg and "chat_template" in ds_cfg:
         chat_template = ds_cfg["chat_template"]
         try:
-            chat_template = chat_templates(
-                chat_template, system_message=cfg.default_system_message
-            )
+            chat_template = chat_templates(chat_template)
         except ValueError:
             pass

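The practical effect of the reverted block above: the config's default system message no longer feeds into template selection; a per-dataset `chat_template` value is looked up by name and, if unknown, kept verbatim as a Jinja string. A small standalone sketch of that resolution logic (the `resolve_chat_template` helper is hypothetical, written only to illustrate the behaviour):

```python
from typing import Optional

from axolotl.utils.chat_templates import chat_templates

def resolve_chat_template(ds_cfg: Optional[dict]) -> str:
    """Mirror of the selection logic in orpo/chat_template.py after this revert."""
    chat_template = chat_templates("chatml")  # default: built-in chatml template string
    if ds_cfg and "chat_template" in ds_cfg:
        chat_template = ds_cfg["chat_template"]
        try:
            # a known name ("chatml", "gemma", ...) resolves to its Jinja template
            chat_template = chat_templates(chat_template)
        except ValueError:
            pass  # unknown name: keep the value verbatim as a literal Jinja template
    return chat_template

# e.g. a dataset entry with `chat_template: gemma` picks the gemma template
print(resolve_chat_template({"chat_template": "gemma"})[:40])
```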
15 changes: 3 additions & 12 deletions src/axolotl/utils/chat_templates.py
@@ -2,10 +2,9 @@
 This module provides functionality for selecting chat templates based on user choices.
 These templates are used for formatting messages in a conversation.
 """
-from typing import Optional


-def chat_templates(user_choice: str, system_message: Optional[str] = None):
+def chat_templates(user_choice: str):
     """
     Finds the correct chat_template for the tokenizer_config.

@@ -19,22 +18,14 @@ def chat_templates(user_choice: str, system_message: Optional[str] = None):
         ValueError: If the user_choice is not found in the templates.
     """

-    default_system_message = "You are a helpful assistant."
-    template_system_message: str = (
-        system_message if system_message is not None else default_system_message
-    )
-
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 and system_message %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
     }

     if user_choice in templates:
-        template = templates[user_choice]
-        if default_system_message in template:
-            template = template.replace(default_system_message, template_system_message)
-        return template
+        return templates[user_choice]

     raise ValueError(f"Template '{user_choice}' not found.")
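To see what the reverted `chatml` entry now produces: it no longer injects a default system prompt, so a system turn appears only when the messages themselves contain one. A quick check (assumes `jinja2` is installed; this is an illustration, not how Axolotl applies the template internally):

```python
from jinja2 import Template

from axolotl.utils.chat_templates import chat_templates

messages = [
    {"role": "user", "content": "What does ORPO stand for?"},
]
# Render the chatml template directly with jinja2 to inspect the output
rendered = Template(chat_templates("chatml")).render(
    messages=messages, add_generation_prompt=True
)
print(rendered)
# <|im_start|>user
# What does ORPO stand for?<|im_end|>
# <|im_start|>assistant
#
# With the previous template, a "You are a helpful assistant." system turn
# would have been prepended here even though no system message was supplied.
```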
