diff --git a/src/alignment/data.py b/src/alignment/data.py
index 84544c68..56a4af62 100644
--- a/src/alignment/data.py
+++ b/src/alignment/data.py
@@ -32,7 +32,7 @@ def maybe_insert_system_message(messages, tokenizer):
     # chat template can be one of two attributes, we check in order
     chat_template = tokenizer.chat_template
     if chat_template is None:
-        chat_template = tokenizer.default_chat_template
+        chat_template = tokenizer.get_chat_template()
 
     # confirm the jinja template refers to a system message before inserting
     if "system" in chat_template or "<|im_start|>" in chat_template:
diff --git a/tests/test_data.py b/tests/test_data.py
index 28483c36..bcf600f5 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -122,21 +122,21 @@ def setUp(self):
         )
 
     def test_maybe_insert_system_message(self):
-        # does not accept system prompt
-        mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-        # accepts system prompt. use codellama since it has no HF token requirement
-        llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
+        # Chat template that does not accept a system prompt. Use a community checkpoint since it has no HF token requirement
+        tokenizer_sys_excl = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
+        # Chat template that accepts a system prompt
+        tokenizer_sys_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
         messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]
         messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]
 
-        mistral_messages = deepcopy(messages_sys_excl)
-        llama_messages = deepcopy(messages_sys_excl)
-        maybe_insert_system_message(mistral_messages, mistral_tokenizer)
-        maybe_insert_system_message(llama_messages, llama_tokenizer)
+        messages_proc_excl = deepcopy(messages_sys_excl)
+        messages_proc_incl = deepcopy(messages_sys_excl)
+        maybe_insert_system_message(messages_proc_excl, tokenizer_sys_excl)
+        maybe_insert_system_message(messages_proc_incl, tokenizer_sys_incl)
 
-        # output from mistral should not have a system message, output from llama should
-        self.assertEqual(mistral_messages, messages_sys_excl)
-        self.assertEqual(llama_messages, messages_sys_incl)
+        # Output from Mistral should not have a system message; output from Qwen should
+        self.assertEqual(messages_proc_excl, messages_sys_excl)
+        self.assertEqual(messages_proc_incl, messages_sys_incl)
 
     def test_sft(self):
         dataset = self.dataset.map(
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index 07bb6813..e0fc6fe2 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -15,7 +15,6 @@
 import unittest
 
 import torch
-from transformers import AutoTokenizer
 
 from alignment import (
     DataArguments,
@@ -64,19 +63,6 @@ def test_default_chat_template(self):
         tokenizer = get_tokenizer(self.model_args, DataArguments())
         self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
 
-    def test_default_chat_template_no_overwrite(self):
-        """
-        If no chat template is passed explicitly in the config, then for models with a
-        `default_chat_template` but no `chat_template` we do not set a `chat_template`,
-        and that we do not change `default_chat_template`
-        """
-        model_args = ModelArguments(model_name_or_path="m-a-p/OpenCodeInterpreter-SC2-7B")
-        base_tokenizer = AutoTokenizer.from_pretrained("m-a-p/OpenCodeInterpreter-SC2-7B")
-        processed_tokenizer = get_tokenizer(model_args, DataArguments())
-
-        assert getattr(processed_tokenizer, "chat_template") is None
-        self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template)
-
     def test_chatml_chat_template(self):
         chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
         tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))