diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
new file mode 100644
index 0000000000..59d667b8db
--- /dev/null
+++ b/examples/phi/lora-3.5.yaml
@@ -0,0 +1,76 @@
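+# LoRA fine-tuning of Phi-3.5-mini-instruct with the base model loaded in 8-bit, using the phi_3 chat template.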
+base_model: microsoft/Phi-3.5-mini-instruct
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: phi_3
+datasets:
+ - path: fozziethebeat/alpaca_messages_2k_test
+ type: chat_template
+ chat_template: phi_3
+ field_messages: messages
+ message_field_role: role
+ message_field_content: content
+ roles:
+ user:
+ - user
+ assistant:
+ - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 4
+num_epochs: 2
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bfloat16: true
+bf16: true
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 4
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 19e36531a5..717367eefa 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -24,8 +24,8 @@ def __init__(
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
- message_field_training: str = "train",
- message_field_training_detail: str = "train_detail",
+ message_field_training: Optional[str] = None,
+ message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -186,7 +186,7 @@ def __init__(
train_on_inputs,
sequence_len,
roles_to_train=None,
- train_on_eos="last",
+ train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
@@ -201,6 +201,37 @@ def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt):
+        # Fall back to the simple, reliable legacy behavior when no advanced masking options are configured.
+ if (
+ not self.roles_to_train
+ and not self.train_on_eos
+ and not self.prompter.message_field_training
+ and not self.prompter.message_field_training_detail
+ ):
+ turns = self.get_conversation_thread(prompt)
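+            # The prompt without the final turn gives the prefix length masked below when train_on_inputs is False.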
+ prompt_ids = self.prompter.build_prompt(
+ turns[:-1], add_generation_prompt=True
+ )
+ input_ids = self.prompter.build_prompt(turns)
+
+ if not self.train_on_inputs:
+ user_prompt_len = len(prompt_ids)
+ labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
+ else:
+ labels = input_ids
+
+ tokenized_prompt = {
+ "input_ids": input_ids,
+ "labels": labels,
+ "attention_mask": [1] * len(input_ids),
+ }
+
+ return tokenized_prompt
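+
+        # Advanced masking path: at least one of these options disabled the fallback above.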
+ LOG.info(self.roles_to_train)
+ LOG.info(self.train_on_eos)
+ LOG.info(self.prompter.message_field_training)
+ LOG.info(self.prompter.message_field_training_detail)
+
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
@@ -219,9 +250,11 @@ def tokenize_prompt(self, prompt):
should_train = (
train_turn
if train_turn is not None
- else bool(train_detail is not None)
- if train_detail is not None
- else self.train_on_inputs or role in self.roles_to_train
+ else (
+ bool(train_detail is not None)
+ if train_detail is not None
+ else self.train_on_inputs or role in self.roles_to_train
+ )
)
LOG.debug(f"Should train: {should_train}")
@@ -344,9 +377,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
- "message_field_training": ds_cfg.get("message_field_training", "training"),
+ "message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
- "message_field_training_detail", "train_detail"
+ "message_field_training_detail",
+ None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -357,8 +391,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
- "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
- "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
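+        # Empty/None defaults keep the simple legacy tokenization path unless set explicitly in the dataset config.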
+ "roles_to_train": ds_cfg.get("roles_to_train", []),
+ "train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = ChatTemplateStrategy(
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 51f88b1bdf..7a96f5c1e1 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -26,6 +26,7 @@ def chat_templates(user_choice: str):
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+ "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
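+        # phi_35: unlike phi_3 above, it omits the leading bos_token, emits the assistant header only when add_generation_prompt is set, and appends eos_token otherwise.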
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
"jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- 
"\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined 
%}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n',
}
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 9044047cce..458bacdb12 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -189,6 +189,7 @@ class ChatTemplate(str, Enum):
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
+ phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py
new file mode 100644
index 0000000000..43423f7255
--- /dev/null
+++ b/tests/prompt_strategies/conftest.py
@@ -0,0 +1,71 @@
+"""
+Shared fixtures for the prompt strategies tests.
+"""
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+ return Dataset.from_list(
+ [
+ {
+ "messages": [
+ {"role": "user", "content": "hello"},
+ {"role": "assistant", "content": "hello"},
+ {"role": "user", "content": "goodbye"},
+ {"role": "assistant", "content": "goodbye"},
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="sharegpt_dataset")
+def fixture_sharegpt_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "conversations": [
+ {"from": "human", "value": "hello"},
+ {"from": "gpt", "value": "hello"},
+ {"from": "human", "value": "goodbye"},
+ {"from": "gpt", "value": "goodbye"},
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="basic_dataset")
+def fixture_basic_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "conversations": [
+ {"from": "system", "value": "You are an AI assistant."},
+ {"from": "human", "value": "Hello"},
+ {"from": "assistant", "value": "Hi there!"},
+ {"from": "human", "value": "How are you?"},
+ {"from": "assistant", "value": "I'm doing well, thank you!"},
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
+
+ return tokenizer
+
+
+@pytest.fixture(name="phi35_tokenizer")
+def fixture_phi35_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
+ return tokenizer
diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py
index e2fc0f6a52..28210b7ae8 100644
--- a/tests/prompt_strategies/test_chat_templates.py
+++ b/tests/prompt_strategies/test_chat_templates.py
@@ -5,10 +5,6 @@
import logging
import unittest
-import pytest
-from datasets import Dataset
-from transformers import AutoTokenizer
-
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
@@ -22,657 +18,6 @@
LOG = logging.getLogger("axolotl")
-@pytest.fixture(name="assistant_dataset")
-def fixture_assistant_dataset():
- return Dataset.from_list(
- [
- {
- "messages": [
- {"role": "user", "content": "hello"},
- {"role": "assistant", "content": "hello"},
- {"role": "user", "content": "goodbye"},
- {"role": "assistant", "content": "goodbye"},
- ]
- }
- ]
- )
-
-
-@pytest.fixture(name="sharegpt_dataset")
-def fixture_sharegpt_dataset():
- # pylint: disable=duplicate-code
- return Dataset.from_list(
- [
- {
- "conversations": [
- {"from": "human", "value": "hello"},
- {"from": "gpt", "value": "hello"},
- {"from": "human", "value": "goodbye"},
- {"from": "gpt", "value": "goodbye"},
- ]
- }
- ]
- )
-
-
-@pytest.fixture(name="basic_dataset")
-def fixture_basic_dataset():
- # pylint: disable=duplicate-code
- return Dataset.from_list(
- [
- {
- "conversations": [
- {"from": "system", "value": "You are an AI assistant."},
- {"from": "human", "value": "Hello"},
- {"from": "assistant", "value": "Hi there!"},
- {"from": "human", "value": "How are you?"},
- {"from": "assistant", "value": "I'm doing well, thank you!"},
- ]
- }
- ]
- )
-
-
-@pytest.fixture(name="llama3_tokenizer")
-def fixture_llama3_tokenizer():
- tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
-
- return tokenizer
-
-
-class TestChatTemplateConfigurations:
- """
- Test class for various configurations of ChatTemplateStrategy.
- """
-
- @staticmethod
- def find_sublist(full_list, sub_list):
- token_count = len(sub_list)
- for index in range(len(full_list) - token_count + 1):
- if full_list[index : index + token_count] == sub_list:
- return index
- return -1
-
- def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_inputs=True")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=True,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- # Check the behavior of human inputs
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- labeled = all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(input_ids)]
- )
- LOG.debug(
- f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
- )
-
- LOG.debug("Full labels: %s", labels)
- LOG.debug("Full input_ids: %s", input_ids)
-
- def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_inputs=False")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that only assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- # Verify that human inputs are not labeled
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- LOG.debug(
- f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
- assert all(
- label == IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(input_ids)]
- ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
-
- def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing roles_to_train with assistant only")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that only assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing roles_to_train with all roles")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=True,
- sequence_len=512,
- roles_to_train=["human", "assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Verify that all responses are labeled (except for special tokens)
- all_responses = [
- "Hello",
- "Hi there!",
- "How are you?",
- "I'm doing well, thank you!",
- ]
- for response in all_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with empty roles_to_train")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=[],
- train_on_eos="none", # Add this line
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
-
- # Verify that no labels are set when roles_to_train is empty
- LOG.debug("Full labels: %s", labels)
- assert all(
- label == IGNORE_TOKEN_ID for label in labels
- ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
-
- def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='all'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="all",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- eos_indices = [
- i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
- ]
-
- assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
- for eos_idx in eos_indices:
- assert (
- labels[eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected EOS token at index {eos_idx} to be labeled"
-
- def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='turn'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="turn",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
-
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
-
- eos_idx = start_idx + len(response_ids)
- while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
- eos_idx += 1
-
- assert eos_idx < len(
- input_ids
- ), f"Could not find EOS token after '{response}'"
- assert (
- labels[eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected EOS token after assistant response '{response}' to be labeled"
-
- # Check that EOS tokens after human inputs are not labeled
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
-
- eos_idx = start_idx + len(input_ids)
- while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
- eos_idx += 1
-
- assert (
- labels[eos_idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token after human input '{input_text}' to not be labeled"
-
- def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='last'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="last",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- eos_indices = [
- i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
- ]
-
- assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
- last_eos_idx = eos_indices[-1]
-
- # Check that only the last EOS token is labeled
- for idx in eos_indices[:-1]:
- assert (
- labels[idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token at index {idx} to not be labeled"
- assert (
- labels[last_eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected last EOS token at index {last_eos_idx} to be labeled"
-
- def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_eos='none'")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- train_on_eos="none",
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- eos_token_id = llama3_tokenizer.eos_token_id
- eos_indices = [
- i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
- ]
-
- assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
- for eos_idx in eos_indices:
- assert (
- labels[eos_idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token at index {eos_idx} to not be labeled"
-
- def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with drop_system_message=True")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(
- llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
- ),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["assistant"],
- )
- res = strategy.tokenize_prompt(basic_dataset[0])
- input_ids = res["input_ids"]
-
- # Check if system message is not present in input_ids
- system_message = "You are an AI assistant."
- system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
- assert (
- self.find_sublist(input_ids, system_ids) == -1
- ), "Expected system message to be dropped"
-
- def test_custom_roles(self, llama3_tokenizer):
- LOG.info("Testing with custom roles mapping")
- custom_roles = {
- "user": ["human", "user"],
- "assistant": ["ai", "assistant"],
- "system": ["context"],
- }
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(
- llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
- ),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=["ai"],
- )
-
- # Create a new dataset with modified role names
- modified_conversations = [
- {"from": "context", "value": "You are an AI assistant."},
- {"from": "human", "value": "Hello"},
- {"from": "ai", "value": "Hi there!"},
- {"from": "human", "value": "How are you?"},
- {"from": "ai", "value": "I'm doing well, thank you!"},
- ]
-
- modified_dataset = Dataset.from_dict(
- {"conversations": [modified_conversations]}
- )
-
- res = strategy.tokenize_prompt(modified_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Check if AI responses are labeled correctly
- ai_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in ai_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- assert start_idx != -1, f"Could not find response '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for AI response '{response}' to be set"
-
- # Check if human messages are not labeled
- human_messages = ["Hello", "How are you?"]
- for message in human_messages:
- message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, message_ids)
- assert start_idx != -1, f"Could not find message '{message}' in input_ids"
- assert all(
- label == IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(message_ids)]
- ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
-
- def test_message_field_training(self, llama3_tokenizer):
- LOG.info("Testing with message_field_training")
- strategy = ChatTemplateStrategy(
- ChatTemplatePrompter(
- llama3_tokenizer,
- chat_templates("llama3"),
- message_field_training="train",
- message_field_training_detail="train_detail",
- ),
- tokenizer=llama3_tokenizer,
- train_on_inputs=False,
- sequence_len=512,
- roles_to_train=[],
- )
-
- # Create a new dataset with the train and train_detail fields
- modified_conversation = [
- {"from": "system", "value": "You are an AI assistant.", "train": False},
- {"from": "human", "value": "Hello", "train": False},
- {"from": "assistant", "value": "Hello", "train": True},
- {"from": "human", "value": "How are you?", "train": True},
- {
- "from": "assistant",
- "value": "I'm doing very well, thank you!",
- "train_detail": [
- {"begin_offset": 0, "end_offset": 8, "train": False},
- {"begin_offset": 9, "end_offset": 18, "train": True},
- {"begin_offset": 19, "end_offset": 30, "train": False},
- ],
- },
- {
- "from": "human",
- "value": "I'm doing very well, thank you!",
- "train": False,
- },
- {"from": "assistant", "value": "Hi there!", "train": True},
- ]
-
- modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
-
- res = strategy.tokenize_prompt(modified_dataset[0])
- labels = res["labels"]
- input_ids = res["input_ids"]
-
- # Function to find all occurrences of a sublist
- def find_all_sublists(full_list, sub_list):
- indices = []
- for index in range(len(full_list) - len(sub_list) + 1):
- if full_list[index : index + len(sub_list)] == sub_list:
- indices.append(index)
- return indices
-
- # Keep track of which occurrences we've processed
- processed_occurrences = {}
- # Check if messages are labeled correctly based on train or train_detail
- for i, turn in enumerate(modified_conversation):
- turn_tokens = llama3_tokenizer.encode(
- turn["value"], add_special_tokens=False
- )
- occurrences = find_all_sublists(input_ids, turn_tokens)
- turn_key = turn["value"]
- if turn_key not in processed_occurrences:
- processed_occurrences[turn_key] = 0
- current_occurrence = processed_occurrences[turn_key]
-
- if current_occurrence >= len(occurrences):
- assert (
- False
- ), f"Not enough occurrences found for message: {turn['value']}"
-
- start_idx = occurrences[current_occurrence]
- processed_occurrences[turn_key] += 1
- end_idx = start_idx + len(turn_tokens)
-
- LOG.debug(
- f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
- )
-
- if "train_detail" in turn:
- # Get token offsets
- tokenized_output = llama3_tokenizer(
- turn["value"], return_offsets_mapping=True, add_special_tokens=False
- )
- token_offsets = tokenized_output["offset_mapping"]
-
- # Adjust token offsets as done in the implementation
- for i in range(len(token_offsets) - 1):
- token_offsets[i] = (
- token_offsets[i][0],
- token_offsets[i + 1][0] - 1,
- )
- token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
-
- # Adjust train_details
- adjusted_train_details = strategy.prompter.adjust_train_details(
- turn["train_detail"], token_offsets
- )
-
- LOG.debug(f"Original train_details: {turn['train_detail']}")
- LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
-
- # Handle train_detail
- token_offsets = strategy.prompter.get_offsets_for_train_detail(
- text=turn["value"],
- train_details=adjusted_train_details,
- mask_untrainable=False,
- )
- token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
- text=turn["value"],
- train_details=adjusted_train_details,
- mask_untrainable=True,
- )
- LOG.debug(f"Token offsets: {token_offsets_masked}")
-
- expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
- for i, offset in enumerate(token_offsets_masked):
- if offset != IGNORE_TOKEN_ID:
- expected_labels[i] = turn_tokens[i]
- actual_labels = labels[
- start_idx : start_idx + len(token_offsets_masked)
- ]
- assert (
- actual_labels == expected_labels
- ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
-
- for detail in adjusted_train_details:
- # Find the token indices that correspond to the character offsets
- detail_start = start_idx + next(
- i
- for i, offset in enumerate(token_offsets)
- if offset >= detail["begin_offset"]
- )
- detail_end = start_idx + next(
- (
- i
- for i, offset in enumerate(token_offsets)
- if offset > detail["end_offset"]
- ),
- len(token_offsets),
- )
-
- detail_text = turn["value"][
- detail["begin_offset"] : detail["end_offset"] + 1
- ]
- detail_labels = labels[detail_start:detail_end]
- detail_input_ids = input_ids[detail_start:detail_end]
-
- LOG.debug(
- f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
- )
- LOG.debug(f"Detail input_ids: {detail_input_ids}")
- LOG.debug(f"Detail labels: {detail_labels}")
- LOG.debug(
- f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
- )
- LOG.debug(
- f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
- )
-
- if detail["train"]:
- assert all(
- label != IGNORE_TOKEN_ID for label in detail_labels
- ), (
- f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
- f"Labels({detail_start}:{detail_end}): {detail_labels}, "
- f"InputIDs: {detail_input_ids}, "
- f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
- )
- else:
- assert all(
- label == IGNORE_TOKEN_ID for label in detail_labels
- ), (
- f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
- f"Labels({detail_start}:{detail_end}): {detail_labels}, "
- f"InputIDs: {detail_input_ids}, "
- f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
- )
- else:
- should_train = turn.get("train", False)
- turn_labels = labels[start_idx:end_idx]
-
- LOG.debug(f"Should train: {should_train}")
- LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
- LOG.debug(f"Turn labels: {turn_labels}")
- LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
- LOG.debug(
- f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
- )
-
- if should_train:
- assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
- f"Expected all labels for '{turn['value']}' to be set\n"
- f"Labels({start_idx}:{end_idx}): {turn_labels}, "
- f"InputIDs: {input_ids[start_idx:end_idx]}, "
- f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
- )
- else:
- assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
- f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
- f"Labels({start_idx}:{end_idx}): {turn_labels}, "
- f"InputIDs: {input_ids[start_idx:end_idx]}, "
- f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
- )
-
- LOG.debug(
- f"Processed turn: {turn['from']}, content: '{turn['value']}', "
- f"start_idx: {start_idx}, end_idx: {end_idx}, "
- f"labels: {labels[start_idx:end_idx]}"
- )
-
- LOG.debug(f"Final labels: {labels}")
- LOG.debug(f"Final input_ids: {input_ids}")
-
-
class TestAssistantChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -740,7 +85,6 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset):
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
- roles_to_train=["assistant"],
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
@@ -764,6 +108,64 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset):
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+ def test_phi35(self, phi35_tokenizer, assistant_dataset):
+ LOG.info("Testing phi-3.5 with assistant dataset")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ phi35_tokenizer,
+ chat_templates("phi_35"),
+ message_field_role="role",
+ message_field_content="content",
+ roles={
+ "user": ["user"],
+ "assistant": ["assistant"],
+ "system": ["system"],
+ },
+ ),
+ tokenizer=phi35_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ )
+ strategy.messages = "messages"
+ res = strategy.tokenize_prompt(assistant_dataset[0])
+ input_ids = res["input_ids"]
+ labels = res["labels"]
+ # fmt: off
+ expected_input_ids = [
+ 32010, # user
+ 22172, 32007, # user eot
+ 32001, # assistant
+ 22172, 32007, # assistant eot
+ 32010, # user
+ 1781, 26966, 32007, # user eot
+ 32001, # assistant
+ 1781, 26966, 32007, # assistant eot
+ 32000, # eos
+ ]
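+        # With no roles_to_train or train_on_eos configured, the legacy fallback masks everything except the final assistant turn and the trailing eos.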
+ expected_labels = [
+ -100, # user
+ -100, -100, # user eot
+ -100, # assistant
+            -100, -100,  # assistant eot
+ -100, # user
+ -100, -100, -100, # user eot
+ -100, # assistant
+ 1781, 26966, 32007, # assistant eot
+ 32000, # eos
+ ]
+ # fmt: on
+ LOG.debug(f"Expected input_ids: {expected_input_ids}")
+ LOG.debug(f"Actual input_ids: {input_ids}")
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+ LOG.debug(f"Expected labels : {expected_labels}")
+ LOG.debug(f"Actual labels : {labels}")
+ assert (
+ labels == expected_labels
+        ), f"Labels mismatch: {labels} != {expected_labels}"
+
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
strategy = ChatTemplateStrategy(
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
new file mode 100644
index 0000000000..f18fb39423
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -0,0 +1,615 @@
+"""
+Tests for the chat_template prompt strategy.
+"""
+
+import logging
+import unittest
+
+from datasets import Dataset
+
+from axolotl.prompt_strategies.chat_template import (
+ ChatTemplatePrompter,
+ ChatTemplateStrategy,
+)
+from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.utils.chat_templates import chat_templates
+
+logging.basicConfig(level=logging.DEBUG)
+LOG = logging.getLogger("axolotl")
+
+
+class TestChatTemplateConfigurations:
+ """
+ Test class for various configurations of ChatTemplateStrategy.
+ """
+
+ @staticmethod
+ def find_sublist(full_list, sub_list):
+ token_count = len(sub_list)
+ for index in range(len(full_list) - token_count + 1):
+ if full_list[index : index + token_count] == sub_list:
+ return index
+ return -1
+
+ def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_inputs=True")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=True,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that assistant responses are labeled
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert start_idx != -1, f"Could not find '{response}' in input_ids"
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ # Check the behavior of human inputs
+ human_inputs = ["Hello", "How are you?"]
+ for input_text in human_inputs:
+            human_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, human_ids)
+            labeled = all(
+                label != IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(human_ids)]
+            )
+            LOG.debug(
+                f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {human_ids}, found at: {start_idx}"
+ )
+
+ LOG.debug("Full labels: %s", labels)
+ LOG.debug("Full input_ids: %s", input_ids)
+
+ def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_inputs=False")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that only assistant responses are labeled
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert start_idx != -1, f"Could not find '{response}' in input_ids"
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ # Verify that human inputs are not labeled
+ human_inputs = ["Hello", "How are you?"]
+ for input_text in human_inputs:
+            human_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, human_ids)
+            LOG.debug(
+                f"Human input '{input_text}' expected IDs: {human_ids}, found at: {start_idx}"
+            )
+            assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
+            assert all(
+                label == IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(human_ids)]
+            ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(human_ids)]}"
+
+ def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing roles_to_train with assistant only")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that only assistant responses are labeled
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing roles_to_train with all roles")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=True,
+ sequence_len=512,
+ roles_to_train=["human", "assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Verify that all responses are labeled (except for special tokens)
+ all_responses = [
+ "Hello",
+ "Hi there!",
+ "How are you?",
+ "I'm doing well, thank you!",
+ ]
+ for response in all_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ LOG.debug(
+ f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ )
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with empty roles_to_train")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=[],
+ train_on_eos="none", # Add this line
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+
+ # Verify that no labels are set when roles_to_train is empty
+ LOG.debug("Full labels: %s", labels)
+ assert all(
+ label == IGNORE_TOKEN_ID for label in labels
+ ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
+
+ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='all'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="all",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ eos_indices = [
+ i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+ ]
+
+ assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
+ for eos_idx in eos_indices:
+ assert (
+ labels[eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected EOS token at index {eos_idx} to be labeled"
+
+ def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='turn'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="turn",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+
+ for response in assistant_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ assert start_idx != -1, f"Could not find '{response}' in input_ids"
+
+ eos_idx = start_idx + len(response_ids)
+ while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
+ eos_idx += 1
+
+ assert eos_idx < len(
+ input_ids
+ ), f"Could not find EOS token after '{response}'"
+ assert (
+ labels[eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected EOS token after assistant response '{response}' to be labeled"
+
+ # Check that EOS tokens after human inputs are not labeled
+ human_inputs = ["Hello", "How are you?"]
+ for input_text in human_inputs:
+ input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, input_ids)
+ assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
+
+ eos_idx = start_idx + len(input_ids)
+ while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
+ eos_idx += 1
+
+ assert (
+ labels[eos_idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token after human input '{input_text}' to not be labeled"
+
+ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='last'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="last",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ eos_indices = [
+ i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+ ]
+
+ assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
+ last_eos_idx = eos_indices[-1]
+
+ # Check that only the last EOS token is labeled
+ for idx in eos_indices[:-1]:
+ assert (
+ labels[idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token at index {idx} to not be labeled"
+ assert (
+ labels[last_eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected last EOS token at index {last_eos_idx} to be labeled"
+
+ def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with train_on_eos='none'")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ train_on_eos="none",
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ eos_token_id = llama3_tokenizer.eos_token_id
+ eos_indices = [
+ i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+ ]
+
+ assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
+ for eos_idx in eos_indices:
+ assert (
+ labels[eos_idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token at index {eos_idx} to not be labeled"
+
+ def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
+ LOG.info("Testing with drop_system_message=True")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ input_ids = res["input_ids"]
+
+ # Check if system message is not present in input_ids
+ system_message = "You are an AI assistant."
+ system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
+ assert (
+ self.find_sublist(input_ids, system_ids) == -1
+ ), "Expected system message to be dropped"
+
+ def test_custom_roles(self, llama3_tokenizer):
+ LOG.info("Testing with custom roles mapping")
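+        # Map the dataset's role names ("human", "ai", "context") onto the
+        # canonical user/assistant/system roles used by the chat template.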
+ custom_roles = {
+ "user": ["human", "user"],
+ "assistant": ["ai", "assistant"],
+ "system": ["context"],
+ }
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=["ai"],
+ )
+
+ # Create a new dataset with modified role names
+ modified_conversations = [
+ {"from": "context", "value": "You are an AI assistant."},
+ {"from": "human", "value": "Hello"},
+ {"from": "ai", "value": "Hi there!"},
+ {"from": "human", "value": "How are you?"},
+ {"from": "ai", "value": "I'm doing well, thank you!"},
+ ]
+
+ modified_dataset = Dataset.from_dict(
+ {"conversations": [modified_conversations]}
+ )
+
+ res = strategy.tokenize_prompt(modified_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Check if AI responses are labeled correctly
+ ai_responses = ["Hi there!", "I'm doing well, thank you!"]
+ for response in ai_responses:
+ response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, response_ids)
+ assert start_idx != -1, f"Could not find response '{response}' in input_ids"
+ assert all(
+ label != IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(response_ids)]
+ ), f"Expected labels for AI response '{response}' to be set"
+
+ # Check if human messages are not labeled
+ human_messages = ["Hello", "How are you?"]
+ for message in human_messages:
+ message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
+ start_idx = self.find_sublist(input_ids, message_ids)
+ assert start_idx != -1, f"Could not find message '{message}' in input_ids"
+ assert all(
+ label == IGNORE_TOKEN_ID
+ for label in labels[start_idx : start_idx + len(message_ids)]
+ ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
+
+ def test_message_field_training(self, llama3_tokenizer):
+ LOG.info("Testing with message_field_training")
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer,
+ chat_templates("llama3"),
+ message_field_training="train",
+ message_field_training_detail="train_detail",
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ sequence_len=512,
+ roles_to_train=[],
+ )
+
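+        # With roles_to_train=[], the per-message "train" flag and the
+        # character-offset "train_detail" spans are the only signals deciding
+        # which tokens get labels.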
+ # Create a new dataset with the train and train_detail fields
+ modified_conversation = [
+ {"from": "system", "value": "You are an AI assistant.", "train": False},
+ {"from": "human", "value": "Hello", "train": False},
+ {"from": "assistant", "value": "Hello", "train": True},
+ {"from": "human", "value": "How are you?", "train": True},
+ {
+ "from": "assistant",
+ "value": "I'm doing very well, thank you!",
+ "train_detail": [
+ {"begin_offset": 0, "end_offset": 8, "train": False},
+ {"begin_offset": 9, "end_offset": 18, "train": True},
+ {"begin_offset": 19, "end_offset": 30, "train": False},
+ ],
+ },
+ {
+ "from": "human",
+ "value": "I'm doing very well, thank you!",
+ "train": False,
+ },
+ {"from": "assistant", "value": "Hi there!", "train": True},
+ ]
+
+        modified_dataset = Dataset.from_dict(
+            {"conversations": [modified_conversation]}
+        )
+
+ res = strategy.tokenize_prompt(modified_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Function to find all occurrences of a sublist
+ def find_all_sublists(full_list, sub_list):
+ indices = []
+ for index in range(len(full_list) - len(sub_list) + 1):
+ if full_list[index : index + len(sub_list)] == sub_list:
+ indices.append(index)
+ return indices
+
+        # Track which occurrence of each message text we've processed, since
+        # the same text (e.g. "Hello") can appear in more than one turn
+ processed_occurrences = {}
+ # Check if messages are labeled correctly based on train or train_detail
+ for i, turn in enumerate(modified_conversation):
+ turn_tokens = llama3_tokenizer.encode(
+ turn["value"], add_special_tokens=False
+ )
+ occurrences = find_all_sublists(input_ids, turn_tokens)
+ turn_key = turn["value"]
+ if turn_key not in processed_occurrences:
+ processed_occurrences[turn_key] = 0
+ current_occurrence = processed_occurrences[turn_key]
+
+            assert current_occurrence < len(
+                occurrences
+            ), f"Not enough occurrences found for message: {turn['value']}"
+
+ start_idx = occurrences[current_occurrence]
+ processed_occurrences[turn_key] += 1
+ end_idx = start_idx + len(turn_tokens)
+
+ LOG.debug(
+ f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
+ )
+
+ if "train_detail" in turn:
+ # Get token offsets
+ tokenized_output = llama3_tokenizer(
+ turn["value"], return_offsets_mapping=True, add_special_tokens=False
+ )
+ token_offsets = tokenized_output["offset_mapping"]
+
+                # Adjust token offsets the same way the prompter does: each
+                # token's end offset becomes the character just before the next
+                # token starts, and the last token runs to the end of the text
+                for idx in range(len(token_offsets) - 1):
+                    token_offsets[idx] = (
+                        token_offsets[idx][0],
+                        token_offsets[idx + 1][0] - 1,
+                    )
+                token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
+
+ # Adjust train_details
+ adjusted_train_details = strategy.prompter.adjust_train_details(
+ turn["train_detail"], token_offsets
+ )
+
+ LOG.debug(f"Original train_details: {turn['train_detail']}")
+ LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
+
+ # Handle train_detail
+ token_offsets = strategy.prompter.get_offsets_for_train_detail(
+ text=turn["value"],
+ train_details=adjusted_train_details,
+ mask_untrainable=False,
+ )
+ token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
+ text=turn["value"],
+ train_details=adjusted_train_details,
+ mask_untrainable=True,
+ )
+ LOG.debug(f"Token offsets: {token_offsets_masked}")
+
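+                # Rebuild the expected label sequence from the masked offsets:
+                # masked positions stay IGNORE_TOKEN_ID, trainable positions
+                # copy the corresponding token id.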
+                expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
+                for idx, offset in enumerate(token_offsets_masked):
+                    if offset != IGNORE_TOKEN_ID:
+                        expected_labels[idx] = turn_tokens[idx]
+ actual_labels = labels[
+ start_idx : start_idx + len(token_offsets_masked)
+ ]
+ assert (
+ actual_labels == expected_labels
+ ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
+
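+                # Verify each train_detail span independently by mapping its
+                # character offsets back to token positions in the prompt.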
+ for detail in adjusted_train_details:
+ # Find the token indices that correspond to the character offsets
+ detail_start = start_idx + next(
+ i
+ for i, offset in enumerate(token_offsets)
+ if offset >= detail["begin_offset"]
+ )
+ detail_end = start_idx + next(
+ (
+ i
+ for i, offset in enumerate(token_offsets)
+ if offset > detail["end_offset"]
+ ),
+ len(token_offsets),
+ )
+
+ detail_text = turn["value"][
+ detail["begin_offset"] : detail["end_offset"] + 1
+ ]
+ detail_labels = labels[detail_start:detail_end]
+ detail_input_ids = input_ids[detail_start:detail_end]
+
+ LOG.debug(
+ f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
+ )
+ LOG.debug(f"Detail input_ids: {detail_input_ids}")
+ LOG.debug(f"Detail labels: {detail_labels}")
+ LOG.debug(
+ f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
+ )
+ LOG.debug(
+ f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
+ )
+
+ if detail["train"]:
+ assert all(
+ label != IGNORE_TOKEN_ID for label in detail_labels
+ ), (
+ f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
+ f"Labels({detail_start}:{detail_end}): {detail_labels}, "
+ f"InputIDs: {detail_input_ids}, "
+ f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
+ )
+ else:
+ assert all(
+ label == IGNORE_TOKEN_ID for label in detail_labels
+ ), (
+ f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
+ f"Labels({detail_start}:{detail_end}): {detail_labels}, "
+ f"InputIDs: {detail_input_ids}, "
+ f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
+ )
+ else:
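+                # Whole-turn labeling: the message-level "train" flag decides
+                # whether every token in this turn is labeled or ignored.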
+ should_train = turn.get("train", False)
+ turn_labels = labels[start_idx:end_idx]
+
+ LOG.debug(f"Should train: {should_train}")
+ LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
+ LOG.debug(f"Turn labels: {turn_labels}")
+ LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
+ LOG.debug(
+ f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
+ )
+
+ if should_train:
+ assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
+ f"Expected all labels for '{turn['value']}' to be set\n"
+ f"Labels({start_idx}:{end_idx}): {turn_labels}, "
+ f"InputIDs: {input_ids[start_idx:end_idx]}, "
+ f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
+ )
+ else:
+ assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
+ f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
+ f"Labels({start_idx}:{end_idx}): {turn_labels}, "
+ f"InputIDs: {input_ids[start_idx:end_idx]}, "
+ f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
+ )
+
+ LOG.debug(
+ f"Processed turn: {turn['from']}, content: '{turn['value']}', "
+ f"start_idx: {start_idx}, end_idx: {end_idx}, "
+ f"labels: {labels[start_idx:end_idx]}"
+ )
+
+ LOG.debug(f"Final labels: {labels}")
+ LOG.debug(f"Final input_ids: {input_ids}")
+
+
+if __name__ == "__main__":
+ unittest.main()