diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
index e1065a950f..068261da36 100644
--- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py
+++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements
             else:
                 yield role + ":", ""
         return
-    if self.sep_style == SeparatorStyle.LLAMA2:
+    if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
         if self.system_message:
             if self.messages:
                 # For llama, the system message is incorporated into the first human instruction
@@ -101,6 +101,28 @@ def get_turns( # pylint: disable=too-many-return-statements
             else:
                 yield role, ""
         return
+    if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
+        contains_sys_msg = False
+        if self.system_message:
+            contains_sys_msg = True
+            if self.messages:
+                # There is no clear guidance on how to handle system messages in Mistral, so we just prepend the system message to the first human instruction, separated by a newline
+                first_role, first_msg = self.messages[0]
+                if first_role == self.roles[0]:
+                    system_prompt = self.system_template.format(
+                        system_message=" " + self.system_message
+                    )
+                    system_prompt += first_msg
+                    self.messages.pop(0)
+                    yield "", system_prompt
+        for i, (role, message) in enumerate(self.messages):
+            if message and i == 0 and not contains_sys_msg:
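+                # system_prompt here is the bare template rendered at the top of
+                # get_turns (no local assignment happens when there is no system
+                # message), so stripping it leaves just the `[INST]` prefix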
+ yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a ` [INST]` at the beginning of the first instruction.
+ elif message:
+ yield role + " ", message
+ else:
+ yield role, ""
+ return
     if self.sep_style == SeparatorStyle.CHATGLM:
         # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
         # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
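
For orientation between the two patches: a minimal sketch of how the new "mistral" path is exercised, mirroring the tests further down. Illustrative only: the tokenizer id is a placeholder (any Llama-family tokenizer), and the imports assume the module paths used by tests/test_prompt_tokenizers.py.

    from transformers import LlamaTokenizer

    from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
    from axolotl.prompters import ShareGPTPrompterV2

    # Placeholder model id; the tests use a local Llama tokenizer fixture.
    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")

    strat = ShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2(conversation="mistral"),
        tokenizer,
        False,  # train_on_inputs
        2048,   # sequence_len
    )
    conv = {
        "conversations": [
            {"from": "system", "value": "lorem"},
            {"from": "human", "value": "abc"},
            {"from": "gpt", "value": "ipsum"},
        ]
    }
    ids = strat.tokenize_prompt(conv)["input_ids"]
    # The system message is folded into the first instruction; per the test
    # assertions below this decodes to: '<s> [INST] lorem\nabc [/INST] ipsum</s>'
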
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 6e57ffb370..cea39d0adf 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import unittest
+from copy import deepcopy
 from pathlib import Path
 from typing import Optional
@@ -25,6 +26,50 @@
 LOG = logging.getLogger("axolotl")
+test_data = {
+    "multi_turn_sys": {
+        "conversations": [
+            {"from": "system", "value": "lorem"},
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+            {"from": "human", "value": "123"},
+            {"from": "gpt", "value": "sit"},
+        ]
+    },
+    "single_turn_sys": {
+        "conversations": [
+            {"from": "system", "value": "lorem"},
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+        ]
+    },
+    "single_turn_no_sys": {
+        "conversations": [
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+        ]
+    },
+    "multi_turn_no_sys": {
+        "conversations": [
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+            {"from": "human", "value": "123"},
+            {"from": "gpt", "value": "sit"},
+        ]
+    },
+}
+
+
+def prompt_strat(conversation, tokenizer):
+    "Helper function to create a prompt strategy for testing."
+    prompter = ShareGPTPrompterV2(conversation=conversation)
+    return ShareGPTPromptTokenizingStrategy(
+        prompter,
+        tokenizer,
+        False,
+        2048,
+    )
+
 class TestPromptTokenizationStrategies(unittest.TestCase):
     """
@@ -116,74 +161,68 @@ def test_sharegpt_warnings_turns(self):
     def test_sharegpt_llama(self):
         "Make sure the sharegpt/llama is tokenized and formatted correctly."
-        prompter = ShareGPTPrompterV2(conversation="llama-2")
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
+        strat = prompt_strat("llama-2", self.tokenizer)
 
         def tokenize(conv):
-            return strat.tokenize_prompt(conv)["input_ids"]
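+            # deepcopy so repeated tokenize_prompt calls cannot mutate the
+            # shared test_data fixtures in place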
+            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
 
         def decode(ids):
             return strat.tokenizer.decode(ids)
 
-        # Multi-turn conversations
-        multi_turn_conv = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "123"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
         # fmt: off
-        mt_ids = tokenize(multi_turn_conv)
+        # System message, multi-turn conversations
+        mt_ids = tokenize(test_data['multi_turn_sys'])
         assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
         assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
-        # Single-turn conversations
-        single_turn_conv = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-            ]
-        }
-
-        st_ids = tokenize(single_turn_conv)
+        # System message, single-turn conversations
+        st_ids = tokenize(test_data['single_turn_sys'])
         assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
         assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
         # No system message, single-turn
-        no_sys_conv = {
-            "conversations": [
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-            ]
-        }
-
-        ns_ids = tokenize(no_sys_conv)
+        ns_ids = tokenize(test_data['single_turn_no_sys'])
         assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
         assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
         # No system message, multi-turn
-        no_sys_mt_conv = {
-            "conversations": [
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "123"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
-        ns_mt_ids = tokenize(no_sys_mt_conv)
+        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
         assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
         assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
         # fmt: on
 
+    def test_sharegpt_mistral(self):
+        "Make sure the sharegpt/mistral is tokenized and formatted correctly."
+        strat = prompt_strat("mistral", self.tokenizer)
+
+        def tokenize(conv):
+            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
+
+        def decode(ids):
+            return strat.tokenizer.decode(ids)
+
+        # fmt: off
+        # System message, multi-turn conversations
+        mt_ids = tokenize(test_data['multi_turn_sys'])
+        assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
+        assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+
+        # System message, single-turn conversations
+        st_ids = tokenize(test_data['single_turn_sys'])
+        assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
+        assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
+
+        # No system message, single-turn
+        ns_ids = tokenize(test_data['single_turn_no_sys'])
+        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
+        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
+
+        # No system message, multi-turn
+        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
+        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
+        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+        # fmt: on
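+        # Note the contrast with test_sharegpt_llama above: the mistral
+        # template keeps the system message bare (no <<SYS>> wrapper) and
+        # does not re-insert a BOS token between turns (</s> is followed
+        # directly by ` [INST]`, not by a second <s>).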
+
     def test_sharegpt_changes_roles(self):
         conversation = {
             "roles": ["USER", "CHARACTER"],