diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml new file mode 100644 index 0000000000..59d667b8db --- /dev/null +++ b/examples/phi/lora-3.5.yaml @@ -0,0 +1,76 @@ +base_model: microsoft/Phi-3.5-mini-instruct +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +chat_template: phi_3 +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + chat_template: phi_3 + field_messages: messages + message_field_role: role + message_field_content: content + roles: + user: + - user + assistant: + - assistant + +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 2 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bfloat16: true +bf16: true +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 4 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 19e36531a5..717367eefa 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -24,8 +24,8 @@ def __init__( max_length=2048, message_field_role: str = "from", message_field_content: str = "value", - message_field_training: str = "train", - message_field_training_detail: str = "train_detail", + message_field_training: Optional[str] = None, + message_field_training_detail: Optional[str] = None, roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -186,7 +186,7 @@ def __init__( train_on_inputs, sequence_len, roles_to_train=None, - train_on_eos="last", + train_on_eos=None, ): super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.roles_to_train = roles_to_train if roles_to_train is not None else [] @@ -201,6 +201,37 @@ def messages(self, messages): self._messages = messages def tokenize_prompt(self, prompt): + # Old simple legacy behavior that works reliably. + if ( + not self.roles_to_train + and not self.train_on_eos + and not self.prompter.message_field_training + and not self.prompter.message_field_training_detail + ): + turns = self.get_conversation_thread(prompt) + prompt_ids = self.prompter.build_prompt( + turns[:-1], add_generation_prompt=True + ) + input_ids = self.prompter.build_prompt(turns) + + if not self.train_on_inputs: + user_prompt_len = len(prompt_ids) + labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] + else: + labels = input_ids + + tokenized_prompt = { + "input_ids": input_ids, + "labels": labels, + "attention_mask": [1] * len(input_ids), + } + + return tokenized_prompt + LOG.info(self.roles_to_train) + LOG.info(self.train_on_eos) + LOG.info(self.prompter.message_field_training) + LOG.info(self.prompter.message_field_training_detail) + turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) labels = [IGNORE_TOKEN_ID] * len(input_ids) @@ -219,9 +250,11 @@ def tokenize_prompt(self, prompt): should_train = ( train_turn if train_turn is not None - else bool(train_detail is not None) - if train_detail is not None - else self.train_on_inputs or role in self.roles_to_train + else ( + bool(train_detail is not None) + if train_detail is not None + else self.train_on_inputs or role in self.roles_to_train + ) ) LOG.debug(f"Should train: {should_train}") @@ -344,9 +377,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), "message_field_role": ds_cfg.get("message_field_role", "from"), "message_field_content": ds_cfg.get("message_field_content", "value"), - "message_field_training": ds_cfg.get("message_field_training", "training"), + "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", "train_detail" + "message_field_training_detail", + None, ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), @@ -357,8 +391,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": ds_cfg.get("roles_to_train", []), + "train_on_eos": ds_cfg.get("train_on_eos", None), } strategy = ChatTemplateStrategy( diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 51f88b1bdf..7a96f5c1e1 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -26,6 +26,7 @@ def chat_templates(user_choice: str): "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 9044047cce..458bacdb12 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -189,6 +189,7 @@ class ChatTemplate(str, Enum): cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name + phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name jamba = "jamba" # pylint: disable=invalid-name diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py new file mode 100644 index 0000000000..43423f7255 --- /dev/null +++ b/tests/prompt_strategies/conftest.py @@ -0,0 +1,71 @@ +""" +shared fixtures for prompt strategies tests +""" + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + + +@pytest.fixture(name="assistant_dataset") +def fixture_assistant_dataset(): + return Dataset.from_list( + [ + { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "goodbye"}, + {"role": "assistant", "content": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="sharegpt_dataset") +def fixture_sharegpt_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "human", "value": "hello"}, + {"from": "gpt", "value": "hello"}, + {"from": "human", "value": "goodbye"}, + {"from": "gpt", "value": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="basic_dataset") +def fixture_basic_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "system", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "assistant", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "assistant", "value": "I'm doing well, thank you!"}, + ] + } + ] + ) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") + + return tokenizer + + +@pytest.fixture(name="phi35_tokenizer") +def fixture_phi35_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct") + return tokenizer diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index e2fc0f6a52..28210b7ae8 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -5,10 +5,6 @@ import logging import unittest -import pytest -from datasets import Dataset -from transformers import AutoTokenizer - from axolotl.prompt_strategies.chat_template import ( ChatTemplatePrompter, ChatTemplateStrategy, @@ -22,657 +18,6 @@ LOG = logging.getLogger("axolotl") -@pytest.fixture(name="assistant_dataset") -def fixture_assistant_dataset(): - return Dataset.from_list( - [ - { - "messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hello"}, - {"role": "user", "content": "goodbye"}, - {"role": "assistant", "content": "goodbye"}, - ] - } - ] - ) - - -@pytest.fixture(name="sharegpt_dataset") -def fixture_sharegpt_dataset(): - # pylint: disable=duplicate-code - return Dataset.from_list( - [ - { - "conversations": [ - {"from": "human", "value": "hello"}, - {"from": "gpt", "value": "hello"}, - {"from": "human", "value": "goodbye"}, - {"from": "gpt", "value": "goodbye"}, - ] - } - ] - ) - - -@pytest.fixture(name="basic_dataset") -def fixture_basic_dataset(): - # pylint: disable=duplicate-code - return Dataset.from_list( - [ - { - "conversations": [ - {"from": "system", "value": "You are an AI assistant."}, - {"from": "human", "value": "Hello"}, - {"from": "assistant", "value": "Hi there!"}, - {"from": "human", "value": "How are you?"}, - {"from": "assistant", "value": "I'm doing well, thank you!"}, - ] - } - ] - ) - - -@pytest.fixture(name="llama3_tokenizer") -def fixture_llama3_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") - - return tokenizer - - -class TestChatTemplateConfigurations: - """ - Test class for various configurations of ChatTemplateStrategy. - """ - - @staticmethod - def find_sublist(full_list, sub_list): - token_count = len(sub_list) - for index in range(len(full_list) - token_count + 1): - if full_list[index : index + token_count] == sub_list: - return index - return -1 - - def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_inputs=True") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=True, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - # Check the behavior of human inputs - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - labeled = all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(input_ids)] - ) - LOG.debug( - f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" - ) - - LOG.debug("Full labels: %s", labels) - LOG.debug("Full input_ids: %s", input_ids) - - def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_inputs=False") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that only assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - # Verify that human inputs are not labeled - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - LOG.debug( - f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{input_text}' in input_ids" - assert all( - label == IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(input_ids)] - ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" - - def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing roles_to_train with assistant only") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that only assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing roles_to_train with all roles") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=True, - sequence_len=512, - roles_to_train=["human", "assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that all responses are labeled (except for special tokens) - all_responses = [ - "Hello", - "Hi there!", - "How are you?", - "I'm doing well, thank you!", - ] - for response in all_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with empty roles_to_train") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=[], - train_on_eos="none", # Add this line - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - - # Verify that no labels are set when roles_to_train is empty - LOG.debug("Full labels: %s", labels) - assert all( - label == IGNORE_TOKEN_ID for label in labels - ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" - - def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='all'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="all", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - for eos_idx in eos_indices: - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to be labeled" - - def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='turn'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="turn", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - - eos_idx = start_idx + len(response_ids) - while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: - eos_idx += 1 - - assert eos_idx < len( - input_ids - ), f"Could not find EOS token after '{response}'" - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token after assistant response '{response}' to be labeled" - - # Check that EOS tokens after human inputs are not labeled - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - assert start_idx != -1, f"Could not find '{input_text}' in input_ids" - - eos_idx = start_idx + len(input_ids) - while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: - eos_idx += 1 - - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token after human input '{input_text}' to not be labeled" - - def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='last'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="last", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - last_eos_idx = eos_indices[-1] - - # Check that only the last EOS token is labeled - for idx in eos_indices[:-1]: - assert ( - labels[idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {idx} to not be labeled" - assert ( - labels[last_eos_idx] != IGNORE_TOKEN_ID - ), f"Expected last EOS token at index {last_eos_idx} to be labeled" - - def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='none'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="none", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - for eos_idx in eos_indices: - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to not be labeled" - - def test_drop_system_message(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with drop_system_message=True") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - input_ids = res["input_ids"] - - # Check if system message is not present in input_ids - system_message = "You are an AI assistant." - system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) - assert ( - self.find_sublist(input_ids, system_ids) == -1 - ), "Expected system message to be dropped" - - def test_custom_roles(self, llama3_tokenizer): - LOG.info("Testing with custom roles mapping") - custom_roles = { - "user": ["human", "user"], - "assistant": ["ai", "assistant"], - "system": ["context"], - } - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["ai"], - ) - - # Create a new dataset with modified role names - modified_conversations = [ - {"from": "context", "value": "You are an AI assistant."}, - {"from": "human", "value": "Hello"}, - {"from": "ai", "value": "Hi there!"}, - {"from": "human", "value": "How are you?"}, - {"from": "ai", "value": "I'm doing well, thank you!"}, - ] - - modified_dataset = Dataset.from_dict( - {"conversations": [modified_conversations]} - ) - - res = strategy.tokenize_prompt(modified_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Check if AI responses are labeled correctly - ai_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in ai_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - assert start_idx != -1, f"Could not find response '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for AI response '{response}' to be set" - - # Check if human messages are not labeled - human_messages = ["Hello", "How are you?"] - for message in human_messages: - message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, message_ids) - assert start_idx != -1, f"Could not find message '{message}' in input_ids" - assert all( - label == IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(message_ids)] - ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" - - def test_message_field_training(self, llama3_tokenizer): - LOG.info("Testing with message_field_training") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, - chat_templates("llama3"), - message_field_training="train", - message_field_training_detail="train_detail", - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=[], - ) - - # Create a new dataset with the train and train_detail fields - modified_conversation = [ - {"from": "system", "value": "You are an AI assistant.", "train": False}, - {"from": "human", "value": "Hello", "train": False}, - {"from": "assistant", "value": "Hello", "train": True}, - {"from": "human", "value": "How are you?", "train": True}, - { - "from": "assistant", - "value": "I'm doing very well, thank you!", - "train_detail": [ - {"begin_offset": 0, "end_offset": 8, "train": False}, - {"begin_offset": 9, "end_offset": 18, "train": True}, - {"begin_offset": 19, "end_offset": 30, "train": False}, - ], - }, - { - "from": "human", - "value": "I'm doing very well, thank you!", - "train": False, - }, - {"from": "assistant", "value": "Hi there!", "train": True}, - ] - - modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) - - res = strategy.tokenize_prompt(modified_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Function to find all occurrences of a sublist - def find_all_sublists(full_list, sub_list): - indices = [] - for index in range(len(full_list) - len(sub_list) + 1): - if full_list[index : index + len(sub_list)] == sub_list: - indices.append(index) - return indices - - # Keep track of which occurrences we've processed - processed_occurrences = {} - # Check if messages are labeled correctly based on train or train_detail - for i, turn in enumerate(modified_conversation): - turn_tokens = llama3_tokenizer.encode( - turn["value"], add_special_tokens=False - ) - occurrences = find_all_sublists(input_ids, turn_tokens) - turn_key = turn["value"] - if turn_key not in processed_occurrences: - processed_occurrences[turn_key] = 0 - current_occurrence = processed_occurrences[turn_key] - - if current_occurrence >= len(occurrences): - assert ( - False - ), f"Not enough occurrences found for message: {turn['value']}" - - start_idx = occurrences[current_occurrence] - processed_occurrences[turn_key] += 1 - end_idx = start_idx + len(turn_tokens) - - LOG.debug( - f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" - ) - - if "train_detail" in turn: - # Get token offsets - tokenized_output = llama3_tokenizer( - turn["value"], return_offsets_mapping=True, add_special_tokens=False - ) - token_offsets = tokenized_output["offset_mapping"] - - # Adjust token offsets as done in the implementation - for i in range(len(token_offsets) - 1): - token_offsets[i] = ( - token_offsets[i][0], - token_offsets[i + 1][0] - 1, - ) - token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) - - # Adjust train_details - adjusted_train_details = strategy.prompter.adjust_train_details( - turn["train_detail"], token_offsets - ) - - LOG.debug(f"Original train_details: {turn['train_detail']}") - LOG.debug(f"Adjusted train_details: {adjusted_train_details}") - - # Handle train_detail - token_offsets = strategy.prompter.get_offsets_for_train_detail( - text=turn["value"], - train_details=adjusted_train_details, - mask_untrainable=False, - ) - token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( - text=turn["value"], - train_details=adjusted_train_details, - mask_untrainable=True, - ) - LOG.debug(f"Token offsets: {token_offsets_masked}") - - expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) - for i, offset in enumerate(token_offsets_masked): - if offset != IGNORE_TOKEN_ID: - expected_labels[i] = turn_tokens[i] - actual_labels = labels[ - start_idx : start_idx + len(token_offsets_masked) - ] - assert ( - actual_labels == expected_labels - ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" - - for detail in adjusted_train_details: - # Find the token indices that correspond to the character offsets - detail_start = start_idx + next( - i - for i, offset in enumerate(token_offsets) - if offset >= detail["begin_offset"] - ) - detail_end = start_idx + next( - ( - i - for i, offset in enumerate(token_offsets) - if offset > detail["end_offset"] - ), - len(token_offsets), - ) - - detail_text = turn["value"][ - detail["begin_offset"] : detail["end_offset"] + 1 - ] - detail_labels = labels[detail_start:detail_end] - detail_input_ids = input_ids[detail_start:detail_end] - - LOG.debug( - f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" - ) - LOG.debug(f"Detail input_ids: {detail_input_ids}") - LOG.debug(f"Detail labels: {detail_labels}") - LOG.debug( - f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" - ) - LOG.debug( - f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" - ) - - if detail["train"]: - assert all( - label != IGNORE_TOKEN_ID for label in detail_labels - ), ( - f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " - f"Labels({detail_start}:{detail_end}): {detail_labels}, " - f"InputIDs: {detail_input_ids}, " - f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" - ) - else: - assert all( - label == IGNORE_TOKEN_ID for label in detail_labels - ), ( - f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " - f"Labels({detail_start}:{detail_end}): {detail_labels}, " - f"InputIDs: {detail_input_ids}, " - f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" - ) - else: - should_train = turn.get("train", False) - turn_labels = labels[start_idx:end_idx] - - LOG.debug(f"Should train: {should_train}") - LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") - LOG.debug(f"Turn labels: {turn_labels}") - LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") - LOG.debug( - f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" - ) - - if should_train: - assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( - f"Expected all labels for '{turn['value']}' to be set\n" - f"Labels({start_idx}:{end_idx}): {turn_labels}, " - f"InputIDs: {input_ids[start_idx:end_idx]}, " - f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" - ) - else: - assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( - f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" - f"Labels({start_idx}:{end_idx}): {turn_labels}, " - f"InputIDs: {input_ids[start_idx:end_idx]}, " - f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" - ) - - LOG.debug( - f"Processed turn: {turn['from']}, content: '{turn['value']}', " - f"start_idx: {start_idx}, end_idx: {end_idx}, " - f"labels: {labels[start_idx:end_idx]}" - ) - - LOG.debug(f"Final labels: {labels}") - LOG.debug(f"Final input_ids: {input_ids}") - - class TestAssistantChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. @@ -740,7 +85,6 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, - roles_to_train=["assistant"], ) strategy.messages = "messages" res = strategy.tokenize_prompt(assistant_dataset[0]) @@ -764,6 +108,64 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): input_ids == expected_input_ids ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + def test_phi35(self, phi35_tokenizer, assistant_dataset): + LOG.info("Testing phi-3.5 with assistant dataset") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + phi35_tokenizer, + chat_templates("phi_35"), + message_field_role="role", + message_field_content="content", + roles={ + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ), + tokenizer=phi35_tokenizer, + train_on_inputs=False, + sequence_len=512, + ) + strategy.messages = "messages" + res = strategy.tokenize_prompt(assistant_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + # fmt: off + expected_input_ids = [ + 32010, # user + 22172, 32007, # user eot + 32001, # assistant + 22172, 32007, # assistant eot + 32010, # user + 1781, 26966, 32007, # user eot + 32001, # assistant + 1781, 26966, 32007, # assistant eot + 32000, # eos + ] + expected_labels = [ + -100, # user + -100, -100, # user eot + -100, # assistant + -100, -100, # assistant eot, + -100, # user + -100, -100, -100, # user eot + -100, # assistant + 1781, 26966, 32007, # assistant eot + 32000, # eos + ] + # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + + LOG.debug(f"Expected labels : {expected_labels}") + LOG.debug(f"Actual labels : {labels}") + assert ( + labels == expected_labels + ), f"Input IDs mismatch: {labels} != {expected_labels}" + def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): LOG.info("Testing llama-3 with assistant dataset including training data") strategy = ChatTemplateStrategy( diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py new file mode 100644 index 0000000000..f18fb39423 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -0,0 +1,615 @@ +""" +tests for chat_template prompt strategy +""" + +import logging +import unittest + +from datasets import Dataset + +from axolotl.prompt_strategies.chat_template import ( + ChatTemplatePrompter, + ChatTemplateStrategy, +) +from axolotl.prompters import IGNORE_TOKEN_ID +from axolotl.utils.chat_templates import chat_templates + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + + +class TestChatTemplateConfigurations: + """ + Test class for various configurations of ChatTemplateStrategy. + """ + + @staticmethod + def find_sublist(full_list, sub_list): + token_count = len(sub_list) + for index in range(len(full_list) - token_count + 1): + if full_list[index : index + token_count] == sub_list: + return index + return -1 + + def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Check the behavior of human inputs + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + labeled = all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ) + LOG.debug( + f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" + ) + + LOG.debug("Full labels: %s", labels) + LOG.debug("Full input_ids: %s", input_ids) + + def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=False") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Verify that human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + LOG.debug( + f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" + + def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with assistant only") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with all roles") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["human", "assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that all responses are labeled (except for special tokens) + all_responses = [ + "Hello", + "Hi there!", + "How are you?", + "I'm doing well, thank you!", + ] + for response in all_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with empty roles_to_train") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + train_on_eos="none", # Add this line + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + + # Verify that no labels are set when roles_to_train is empty + LOG.debug("Full labels: %s", labels) + assert all( + label == IGNORE_TOKEN_ID for label in labels + ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + + def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='all'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="all", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to be labeled" + + def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='turn'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="turn", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + + eos_idx = start_idx + len(response_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert eos_idx < len( + input_ids + ), f"Could not find EOS token after '{response}'" + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token after assistant response '{response}' to be labeled" + + # Check that EOS tokens after human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + + eos_idx = start_idx + len(input_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token after human input '{input_text}' to not be labeled" + + def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='last'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="last", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + last_eos_idx = eos_indices[-1] + + # Check that only the last EOS token is labeled + for idx in eos_indices[:-1]: + assert ( + labels[idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {idx} to not be labeled" + assert ( + labels[last_eos_idx] != IGNORE_TOKEN_ID + ), f"Expected last EOS token at index {last_eos_idx} to be labeled" + + def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='none'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="none", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to not be labeled" + + def test_drop_system_message(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with drop_system_message=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + + # Check if system message is not present in input_ids + system_message = "You are an AI assistant." + system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) + assert ( + self.find_sublist(input_ids, system_ids) == -1 + ), "Expected system message to be dropped" + + def test_custom_roles(self, llama3_tokenizer): + LOG.info("Testing with custom roles mapping") + custom_roles = { + "user": ["human", "user"], + "assistant": ["ai", "assistant"], + "system": ["context"], + } + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["ai"], + ) + + # Create a new dataset with modified role names + modified_conversations = [ + {"from": "context", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "ai", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "ai", "value": "I'm doing well, thank you!"}, + ] + + modified_dataset = Dataset.from_dict( + {"conversations": [modified_conversations]} + ) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Check if AI responses are labeled correctly + ai_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in ai_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find response '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for AI response '{response}' to be set" + + # Check if human messages are not labeled + human_messages = ["Hello", "How are you?"] + for message in human_messages: + message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, message_ids) + assert start_idx != -1, f"Could not find message '{message}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(message_ids)] + ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" + + def test_message_field_training(self, llama3_tokenizer): + LOG.info("Testing with message_field_training") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_training="train", + message_field_training_detail="train_detail", + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + ) + + # Create a new dataset with the train and train_detail fields + modified_conversation = [ + {"from": "system", "value": "You are an AI assistant.", "train": False}, + {"from": "human", "value": "Hello", "train": False}, + {"from": "assistant", "value": "Hello", "train": True}, + {"from": "human", "value": "How are you?", "train": True}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": False}, + {"begin_offset": 9, "end_offset": 18, "train": True}, + {"begin_offset": 19, "end_offset": 30, "train": False}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": False, + }, + {"from": "assistant", "value": "Hi there!", "train": True}, + ] + + modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Function to find all occurrences of a sublist + def find_all_sublists(full_list, sub_list): + indices = [] + for index in range(len(full_list) - len(sub_list) + 1): + if full_list[index : index + len(sub_list)] == sub_list: + indices.append(index) + return indices + + # Keep track of which occurrences we've processed + processed_occurrences = {} + # Check if messages are labeled correctly based on train or train_detail + for i, turn in enumerate(modified_conversation): + turn_tokens = llama3_tokenizer.encode( + turn["value"], add_special_tokens=False + ) + occurrences = find_all_sublists(input_ids, turn_tokens) + turn_key = turn["value"] + if turn_key not in processed_occurrences: + processed_occurrences[turn_key] = 0 + current_occurrence = processed_occurrences[turn_key] + + if current_occurrence >= len(occurrences): + assert ( + False + ), f"Not enough occurrences found for message: {turn['value']}" + + start_idx = occurrences[current_occurrence] + processed_occurrences[turn_key] += 1 + end_idx = start_idx + len(turn_tokens) + + LOG.debug( + f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" + ) + + if "train_detail" in turn: + # Get token offsets + tokenized_output = llama3_tokenizer( + turn["value"], return_offsets_mapping=True, add_special_tokens=False + ) + token_offsets = tokenized_output["offset_mapping"] + + # Adjust token offsets as done in the implementation + for i in range(len(token_offsets) - 1): + token_offsets[i] = ( + token_offsets[i][0], + token_offsets[i + 1][0] - 1, + ) + token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) + + # Adjust train_details + adjusted_train_details = strategy.prompter.adjust_train_details( + turn["train_detail"], token_offsets + ) + + LOG.debug(f"Original train_details: {turn['train_detail']}") + LOG.debug(f"Adjusted train_details: {adjusted_train_details}") + + # Handle train_detail + token_offsets = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=False, + ) + token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=True, + ) + LOG.debug(f"Token offsets: {token_offsets_masked}") + + expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) + for i, offset in enumerate(token_offsets_masked): + if offset != IGNORE_TOKEN_ID: + expected_labels[i] = turn_tokens[i] + actual_labels = labels[ + start_idx : start_idx + len(token_offsets_masked) + ] + assert ( + actual_labels == expected_labels + ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + + for detail in adjusted_train_details: + # Find the token indices that correspond to the character offsets + detail_start = start_idx + next( + i + for i, offset in enumerate(token_offsets) + if offset >= detail["begin_offset"] + ) + detail_end = start_idx + next( + ( + i + for i, offset in enumerate(token_offsets) + if offset > detail["end_offset"] + ), + len(token_offsets), + ) + + detail_text = turn["value"][ + detail["begin_offset"] : detail["end_offset"] + 1 + ] + detail_labels = labels[detail_start:detail_end] + detail_input_ids = input_ids[detail_start:detail_end] + + LOG.debug( + f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" + ) + LOG.debug(f"Detail input_ids: {detail_input_ids}") + LOG.debug(f"Detail labels: {detail_labels}") + LOG.debug( + f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" + ) + LOG.debug( + f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" + ) + + if detail["train"]: + assert all( + label != IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + assert all( + label == IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + should_train = turn.get("train", False) + turn_labels = labels[start_idx:end_idx] + + LOG.debug(f"Should train: {should_train}") + LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") + LOG.debug(f"Turn labels: {turn_labels}") + LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") + LOG.debug( + f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" + ) + + if should_train: + assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be set\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + else: + assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + + LOG.debug( + f"Processed turn: {turn['from']}, content: '{turn['value']}', " + f"start_idx: {start_idx}, end_idx: {end_idx}, " + f"labels: {labels[start_idx:end_idx]}" + ) + + LOG.debug(f"Final labels: {labels}") + LOG.debug(f"Final input_ids: {input_ids}") + + +if __name__ == "__main__": + unittest.main()