diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
index e1065a950f..068261da36 100644
--- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py
+++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -82,7 +82,7 @@ def get_turns(  # pylint: disable=too-many-return-statements
             else:
                 yield role + ":", ""
         return
-    if self.sep_style == SeparatorStyle.LLAMA2:
+    if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
         if self.system_message:
             if self.messages:
                 # For llama, the system message is incorporated into the first human instruction
@@ -101,6 +101,28 @@ def get_turns(  # pylint: disable=too-many-return-statements
             else:
                 yield role, ""
         return
+    if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
+        contains_sys_msg = False
+        if self.system_message:
+            contains_sys_msg = True
+            if self.messages:
+                # There is no clear guidance on how to handle system messages in Mistral, so we just prepend it to the first human instruction, separated by a newline
+                first_role, first_msg = self.messages[0]
+                if first_role == self.roles[0]:
+                    system_prompt = self.system_template.format(
+                        system_message=" " + self.system_message
+                    )
+                    system_prompt += first_msg
+                    self.messages.pop(0)
+                    yield "", system_prompt
+        for i, (role, message) in enumerate(self.messages):
+            if message and i == 0 and not contains_sys_msg:
+                yield "", system_prompt.strip() + " " + message  # if there is no system message, we need to make sure there is a ` [INST]` at the beginning of the first instruction
+            elif message:
+                yield role + " ", message
+            else:
+                yield role, ""
+        return
     if self.sep_style == SeparatorStyle.CHATGLM:
         # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
         # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 6e57ffb370..cea39d0adf 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import unittest
+from copy import deepcopy
 from pathlib import Path
 from typing import Optional
 
@@ -25,6 +26,50 @@
 LOG = logging.getLogger("axolotl")
 
 
+test_data = {
+    "multi_turn_sys": {
+        "conversations": [
+            {"from": "system", "value": "lorem"},
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+            {"from": "human", "value": "123"},
+            {"from": "gpt", "value": "sit"},
+        ]
+    },
+    "single_turn_sys": {
+        "conversations": [
+            {"from": "system", "value": "lorem"},
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+        ]
+    },
+    "single_turn_no_sys": {
+        "conversations": [
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+        ]
+    },
+    "multi_turn_no_sys": {
+        "conversations": [
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+            {"from": "human", "value": "123"},
+            {"from": "gpt", "value": "sit"},
+        ]
+    },
+}
+
+
+def prompt_strat(conversation, tokenizer):
+    "Helper function to create a prompt strategy for testing."
+    prompter = ShareGPTPrompterV2(conversation=conversation)
+    return ShareGPTPromptTokenizingStrategy(
+        prompter,
+        tokenizer,
+        False,
+        2048,
+    )
+
+
 class TestPromptTokenizationStrategies(unittest.TestCase):
     """
@@ -116,74 +161,68 @@ def test_sharegpt_warnings_turns(self):
 
     def test_sharegpt_llama(self):
         "Make sure the sharegpt/llama is tokenized and formatted correctly."
-        prompter = ShareGPTPrompterV2(conversation="llama-2")
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
+        strat = prompt_strat("llama-2", self.tokenizer)
 
         def tokenize(conv):
-            return strat.tokenize_prompt(conv)["input_ids"]
+            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
 
         def decode(ids):
             return strat.tokenizer.decode(ids)
 
-        # Multi-turn conversations
-        multi_turn_conv = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "123"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
         # fmt: off
-        mt_ids = tokenize(multi_turn_conv)
+        # System message, multi-turn conversations
+        mt_ids = tokenize(test_data['multi_turn_sys'])
         assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
         assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
 
-        # Single-turn conversations
-        single_turn_conv = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-            ]
-        }
-
-        st_ids = tokenize(single_turn_conv)
+        # System message, single-turn conversations
+        st_ids = tokenize(test_data['single_turn_sys'])
         assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
         assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
 
         # No system message, single-turn
-        no_sys_conv = {
-            "conversations": [
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-            ]
-        }
-
-        ns_ids = tokenize(no_sys_conv)
+        ns_ids = tokenize(test_data['single_turn_no_sys'])
         assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
         assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
 
         # No system message, multi-turn
-        no_sys_mt_conv = {
-            "conversations": [
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "123"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
-        ns_mt_ids = tokenize(no_sys_mt_conv)
+        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
         assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
         assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
         # fmt: on
 
+    def test_sharegpt_mistral(self):
+        "Make sure the sharegpt/mistral is tokenized and formatted correctly."
+        strat = prompt_strat("mistral", self.tokenizer)
+
+        def tokenize(conv):
+            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
+
+        def decode(ids):
+            return strat.tokenizer.decode(ids)
+
+        # fmt: off
+        # System message, multi-turn conversations
+        mt_ids = tokenize(test_data['multi_turn_sys'])
+        assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
+        assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+
+        # System message, single-turn conversations
+        st_ids = tokenize(test_data['single_turn_sys'])
+        assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
+        assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
+
+        # No system message, single-turn
+        ns_ids = tokenize(test_data['single_turn_no_sys'])
+        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
+        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
+
+        # No system message, multi-turn
+        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
+        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
+        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+        # fmt: on
+
     def test_sharegpt_changes_roles(self):
         conversation = {
             "roles": ["USER", "CHARACTER"],
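For reviewers, here is a minimal standalone sketch of the turn layout the new mistral branch produces. This is not axolotl's API: `mistral_turns`, `SYSTEM_TEMPLATE`, and `ROLES` are hypothetical names, and the template string and role markers are assumptions inferred from the expected decodes in the tests above. The patched code's no-system-message branch reuses a `system_prompt` computed earlier in `get_turns`; the sketch recomputes it inline.

```python
# Hypothetical standalone sketch of the mistral turn layout added above.
# SYSTEM_TEMPLATE and ROLES are assumptions inferred from the test decodes,
# not values taken from axolotl's registered template.

SYSTEM_TEMPLATE = "[INST]{system_message}\n"  # assumed mistral system template
ROLES = ("[INST]", "[/INST]")                 # assumed role markers


def mistral_turns(system_message, messages):
    """Yield (role, message) turns the way the patched mistral branch does."""
    contains_sys_msg = bool(system_message)
    msgs = list(messages)
    if contains_sys_msg and msgs and msgs[0][0] == ROLES[0]:
        # Prepend the system message to the first human instruction,
        # separated by a newline, and emit both as a single turn.
        system_prompt = SYSTEM_TEMPLATE.format(system_message=" " + system_message)
        system_prompt += msgs.pop(0)[1]
        yield "", system_prompt
    for i, (role, message) in enumerate(msgs):
        if message and i == 0 and not contains_sys_msg:
            # No system message: still open with ` [INST]` before the first instruction.
            yield "", SYSTEM_TEMPLATE.format(system_message="").strip() + " " + message
        elif message:
            yield role + " ", message
        else:
            yield role, ""


if __name__ == "__main__":
    msgs = [("[INST]", "abc"), ("[/INST]", "ipsum"),
            ("[INST]", "123"), ("[/INST]", "sit")]
    print(list(mistral_turns("lorem", msgs)))
    # [('', '[INST] lorem\nabc'), ('[/INST] ', 'ipsum'),
    #  ('[INST] ', '123'), ('[/INST] ', 'sit')]
    print(list(mistral_turns("", msgs)))
    # [('', '[INST] abc'), ('[/INST] ', 'ipsum'),
    #  ('[INST] ', '123'), ('[/INST] ', 'sit')]
```

Joined in order, these turns reproduce the expected decodes in `test_sharegpt_mistral`. Note how the mistral expectations differ from llama-2: consecutive turns are not re-wrapped in `<s>` (there is no BOS token `1` between turns in the expected ids, only the EOS `2` closing each assistant reply), which is exactly what these turn tuples produce.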