diff --git a/docs/rlhf.md b/docs/rlhf.md index 9f5ba05fdb..4f71184fc0 100644 --- a/docs/rlhf.md +++ b/docs/rlhf.md @@ -34,6 +34,21 @@ datasets: rl: ipo ``` +#### ORPO + +Paper: https://arxiv.org/abs/2403.07691 + +```yaml +rl: orpo +orpo_alpha: 0.1 +remove_unused_columns: false + +chat_template: chatml +datasets: + - path: argilla/ultrafeedback-binarized-preferences-cleaned + type: orpo.chat_template +``` + #### Using local dataset files ```yaml datasets: diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index 7455ce1069..26a05ce35e 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -29,13 +29,11 @@ def load( chatml transforms for datasets with system, input, chosen, rejected """ - chat_template = chat_templates("chatml", system_message=cfg.default_system_message) + chat_template = chat_templates("chatml") if ds_cfg and "chat_template" in ds_cfg: chat_template = ds_cfg["chat_template"] try: - chat_template = chat_templates( - chat_template, system_message=cfg.default_system_message - ) + chat_template = chat_templates(chat_template) except ValueError: pass diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 26af42e19d..fd34b4ea99 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -2,10 +2,9 @@ This module provides functionality for selecting chat templates based on user choices. These templates are used for formatting messages in a conversation. """ -from typing import Optional -def chat_templates(user_choice: str, system_message: Optional[str] = None): +def chat_templates(user_choice: str): """ Finds the correct chat_template for the tokenizer_config. @@ -19,22 +18,14 @@ def chat_templates(user_choice: str, system_message: Optional[str] = None): ValueError: If the user_choice is not found in the templates. """ - default_system_message = "You are a helpful assistant." - template_system_message: str = ( - system_message if system_message is not None else default_system_message - ) - templates = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. - "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 and system_message %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", } if user_choice in templates: - template = templates[user_choice] - if default_system_message in template: - template = template.replace(default_system_message, template_system_message) - return template + return templates[user_choice] raise ValueError(f"Template '{user_choice}' not found.")