From 9b8864a0387f6d5d200a084b6d98cdc418c4acc3 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Tue, 19 Dec 2023 17:12:45 -0800 Subject: [PATCH 1/3] fix mistral prompts --- .../fastchat_conversation_turns.py | 21 ++- tests/test_prompt_tokenizers.py | 131 ++++++++++++------ 2 files changed, 105 insertions(+), 47 deletions(-) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index e1065a950f..b35ea05ee6 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.LLAMA2: + if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": if self.system_message: if self.messages: # For llama, the system message is incorporated into the first human instruction @@ -101,6 +101,25 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return + if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": + contains_sys_msg = False + if self.system_message: + contains_sys_msg = True + if self.messages: + # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline + first_role, first_msg = self.messages[0] + if first_role == self.roles[0]: + system_prompt += first_msg + self.messages.pop(0) + yield "", system_prompt + for i, (role, message) in enumerate(self.messages): + if message and i == 0 and not contains_sys_msg: + yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a ` [INST]` at the beginning of the first instruction. + elif message: + yield role + " ", message + else: + yield role, "" + return if self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 6e57ffb370..ffc3a65bef 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -2,6 +2,7 @@ import json import logging import unittest +from copy import deepcopy from pathlib import Path from typing import Optional @@ -25,6 +26,50 @@ LOG = logging.getLogger("axolotl") +test_data = { + "multi_turn_sys": { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + }, + "single_turn_sys": { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + }, + "single_turn_no_sys": { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + }, + "multi_turn_no_sys": { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + }, +} + + +def prompt_strat(conversation, tokenizer): + "Helper function to create a prompt strategy for testing." + prompter = ShareGPTPrompterV2(conversation=conversation) + return ShareGPTPromptTokenizingStrategy( + prompter, + tokenizer, + False, + 2048, + ) + class TestPromptTokenizationStrategies(unittest.TestCase): """ @@ -116,74 +161,68 @@ def test_sharegpt_warnings_turns(self): def test_sharegpt_llama(self): "Make sure the sharegpt/llama is tokenized and formatted correctly." - prompter = ShareGPTPrompterV2(conversation="llama-2") - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) + strat = prompt_strat("llama-2", self.tokenizer) def tokenize(conv): - return strat.tokenize_prompt(conv)["input_ids"] + return strat.tokenize_prompt(deepcopy(conv))["input_ids"] def decode(ids): return strat.tokenizer.decode(ids) - # Multi-turn conversations - multi_turn_conv = { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "123"}, - {"from": "gpt", "value": "sit"}, - ] - } # fmt: off - mt_ids = tokenize(multi_turn_conv) + # System message, multi-turn conversations + mt_ids = tokenize(test_data['multi_turn_sys']) assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - # Single-turn conversations - single_turn_conv = { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - ] - } - - st_ids = tokenize(single_turn_conv) + # System message, single-turn conversations + st_ids = tokenize(test_data['single_turn_sys']) assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] # No system message, single-turn - no_sys_conv = { - "conversations": [ - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - ] - } - - ns_ids = tokenize(no_sys_conv) + ns_ids = tokenize(test_data['single_turn_no_sys']) assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] # No system message, multi-turn - no_sys_mt_conv = { - "conversations": [ - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "123"}, - {"from": "gpt", "value": "sit"}, - ] - } - ns_mt_ids = tokenize(no_sys_mt_conv) + ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] # fmt: on + def test_sharegpt_mistral(self): + "Make sure the sharegpt/mistral is tokenized and formatted correctly." + strat = prompt_strat("mistral", self.tokenizer) + + def tokenize(conv): + return strat.tokenize_prompt(deepcopy(conv))["input_ids"] + + def decode(ids): + return strat.tokenizer.decode(ids) + + # fmt: off + # System message, multi-turn conversations + mt_ids = tokenize(test_data['multi_turn_sys']) + assert decode(mt_ids) == ' [INST]lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 29880, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + + # System message, single-turn conversations + st_ids = tokenize(test_data['single_turn_sys']) + assert decode(st_ids) == ' [INST]lorem\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 29880, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, single-turn + ns_ids = tokenize(test_data['single_turn_no_sys']) + assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' + assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, multi-turn + ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) + assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' + assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + # fmt: on + def test_sharegpt_changes_roles(self): conversation = { "roles": ["USER", "CHARACTER"], From ee3ff48d64873481d0efd95cec9458e72a02c706 Mon Sep 17 00:00:00 2001 From: hamelsmu Date: Tue, 19 Dec 2023 17:55:38 -0800 Subject: [PATCH 2/3] fix spacing --- .../fastchat_conversation_turns.py | 33 ++++++++++--------- tests/test_prompt_tokenizers.py | 8 ++--- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index b35ea05ee6..782d41d8ff 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -30,7 +30,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.ADD_COLON_TWO: + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: seps = [self.sep, self.sep2] yield "", system_prompt + seps[0] for i, (role, message) in enumerate(self.messages): @@ -39,7 +39,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: yield "", system_prompt + self.sep for role, message in self.messages: if message: @@ -47,7 +47,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ": ", "" # must be end with a space return - if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: yield "", "" if system_prompt == "" else system_prompt + self.sep for role, message in self.messages: if message: @@ -55,7 +55,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + "\n", "" return - if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: yield "", system_prompt for role, message in self.messages: if message: @@ -63,7 +63,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - if self.sep_style == SeparatorStyle.NO_COLON_TWO: + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: seps = [self.sep, self.sep2] yield "", system_prompt for i, (role, message) in enumerate(self.messages): @@ -72,7 +72,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - if self.sep_style == SeparatorStyle.RWKV: + elif self.sep_style == SeparatorStyle.RWKV: yield "", system_prompt for i, (role, message) in enumerate(self.messages): if message: @@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": + elif self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": if self.system_message: if self.messages: # For llama, the system message is incorporated into the first human instruction @@ -101,7 +101,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": + elif self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": contains_sys_msg = False if self.system_message: contains_sys_msg = True @@ -109,6 +109,9 @@ def get_turns( # pylint: disable=too-many-return-statements # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline first_role, first_msg = self.messages[0] if first_role == self.roles[0]: + system_prompt = self.system_template.format( + system_message=" " + self.system_message + ) system_prompt += first_msg self.messages.pop(0) yield "", system_prompt @@ -120,7 +123,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - if self.sep_style == SeparatorStyle.CHATGLM: + elif self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 round_add_n = 1 if self.name == "chatglm2" else 0 @@ -136,7 +139,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield f"{role}:", "" return - if self.sep_style == SeparatorStyle.CHATML: + elif self.sep_style == SeparatorStyle.CHATML: yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" for role, message in self.messages: if message: @@ -144,7 +147,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + "\n", "" return - if self.sep_style == SeparatorStyle.CHATINTERN: + elif self.sep_style == SeparatorStyle.CHATINTERN: # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 seps = [self.sep, self.sep2] yield "", system_prompt @@ -155,7 +158,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.DOLLY: + elif self.sep_style == SeparatorStyle.DOLLY: seps = [self.sep, self.sep2] yield "", system_prompt for i, (role, message) in enumerate(self.messages): @@ -165,7 +168,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":\n", "" return - if self.sep_style == SeparatorStyle.PHOENIX: + elif self.sep_style == SeparatorStyle.PHOENIX: yield "", system_prompt for role, message in self.messages: if message: @@ -173,7 +176,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ": " + "", "" return - if self.sep_style == SeparatorStyle.ROBIN: + elif self.sep_style == SeparatorStyle.ROBIN: yield "", system_prompt + self.sep for role, message in self.messages: if message: @@ -181,7 +184,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":\n", "" return - if self.sep_style == SeparatorStyle.FALCON_CHAT: + elif self.sep_style == SeparatorStyle.FALCON_CHAT: if self.system_message: yield "", system_prompt + self.sep for role, message in self.messages: diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index ffc3a65bef..cea39d0adf 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -204,13 +204,13 @@ def decode(ids): # fmt: off # System message, multi-turn conversations mt_ids = tokenize(test_data['multi_turn_sys']) - assert decode(mt_ids) == ' [INST]lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' - assert mt_ids == [1, 518, 25580, 29962, 29880, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] # System message, single-turn conversations st_ids = tokenize(test_data['single_turn_sys']) - assert decode(st_ids) == ' [INST]lorem\nabc [/INST] ipsum' - assert st_ids == [1, 518, 25580, 29962, 29880, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] # No system message, single-turn ns_ids = tokenize(test_data['single_turn_no_sys']) From c37deb4a951eb00e619fc5f26b2425e2c79e5787 Mon Sep 17 00:00:00 2001 From: hamelsmu Date: Tue, 19 Dec 2023 17:58:57 -0800 Subject: [PATCH 3/3] remove elif --- .../fastchat_conversation_turns.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index 782d41d8ff..068261da36 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -30,7 +30,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + if self.sep_style == SeparatorStyle.ADD_COLON_TWO: seps = [self.sep, self.sep2] yield "", system_prompt + seps[0] for i, (role, message) in enumerate(self.messages): @@ -39,7 +39,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: yield "", system_prompt + self.sep for role, message in self.messages: if message: @@ -47,7 +47,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ": ", "" # must be end with a space return - elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: yield "", "" if system_prompt == "" else system_prompt + self.sep for role, message in self.messages: if message: @@ -55,7 +55,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + "\n", "" return - elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: yield "", system_prompt for role, message in self.messages: if message: @@ -63,7 +63,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + if self.sep_style == SeparatorStyle.NO_COLON_TWO: seps = [self.sep, self.sep2] yield "", system_prompt for i, (role, message) in enumerate(self.messages): @@ -72,7 +72,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - elif self.sep_style == SeparatorStyle.RWKV: + if self.sep_style == SeparatorStyle.RWKV: yield "", system_prompt for i, (role, message) in enumerate(self.messages): if message: @@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - elif self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": + if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": if self.system_message: if self.messages: # For llama, the system message is incorporated into the first human instruction @@ -101,7 +101,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - elif self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": + if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": contains_sys_msg = False if self.system_message: contains_sys_msg = True @@ -123,7 +123,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return - elif self.sep_style == SeparatorStyle.CHATGLM: + if self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 round_add_n = 1 if self.name == "chatglm2" else 0 @@ -139,7 +139,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield f"{role}:", "" return - elif self.sep_style == SeparatorStyle.CHATML: + if self.sep_style == SeparatorStyle.CHATML: yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" for role, message in self.messages: if message: @@ -147,7 +147,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + "\n", "" return - elif self.sep_style == SeparatorStyle.CHATINTERN: + if self.sep_style == SeparatorStyle.CHATINTERN: # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 seps = [self.sep, self.sep2] yield "", system_prompt @@ -158,7 +158,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - elif self.sep_style == SeparatorStyle.DOLLY: + if self.sep_style == SeparatorStyle.DOLLY: seps = [self.sep, self.sep2] yield "", system_prompt for i, (role, message) in enumerate(self.messages): @@ -168,7 +168,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":\n", "" return - elif self.sep_style == SeparatorStyle.PHOENIX: + if self.sep_style == SeparatorStyle.PHOENIX: yield "", system_prompt for role, message in self.messages: if message: @@ -176,7 +176,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ": " + "", "" return - elif self.sep_style == SeparatorStyle.ROBIN: + if self.sep_style == SeparatorStyle.ROBIN: yield "", system_prompt + self.sep for role, message in self.messages: if message: @@ -184,7 +184,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":\n", "" return - elif self.sep_style == SeparatorStyle.FALCON_CHAT: + if self.sep_style == SeparatorStyle.FALCON_CHAT: if self.system_message: yield "", system_prompt + self.sep for role, message in self.messages: