From 5ada140ff02edb18dfbf9b2cdc08c13203cb0e7d Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Thu, 14 Dec 2023 10:03:59 -0800 Subject: [PATCH] Fix prompt assembly for llama (#952) * start at index 0 * add test to check for missing turns * apply black * Update test_prompt_tokenizers.py * Update src/axolotl/monkeypatch/fastchat_conversation_turns.py Co-authored-by: Motoki Wu * fix linting * apply black * add more tests for llama/sharegpt * make logic clearer --------- Co-authored-by: Motoki Wu --- .../fastchat_conversation_turns.py | 17 +++-- tests/test_prompt_tokenizers.py | 70 +++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index 19313fb7e2..e1065a950f 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements yield role + ":", "" return if self.sep_style == SeparatorStyle.LLAMA2: - seps = [self.sep, self.sep2] if self.system_message: + if self.messages: + # For llama, the system message is incorporated into the first human instruction + first_role, first_msg = self.messages[0] + if first_role == self.roles[0]: + system_prompt += first_msg + self.messages.pop(0) yield "", system_prompt - else: - yield "", "[INST] " - for i, (role, message) in enumerate(self.messages[1:]): + for i, (role, message) in enumerate(self.messages): if message: - yield role + " ", message + seps[i % 2] + if (i % 2 == 0 and not self.system_message) or ( + i % 2 != 0 and self.system_message + ): + role = " " + role + yield role + " ", message else: yield role, "" return diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 0635bd718b..6e57ffb370 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -114,6 +114,76 @@ def test_sharegpt_warnings_turns(self): in self._caplog.records[0].message ) + def test_sharegpt_llama(self): + "Make sure the sharegpt/llama is tokenized and formatted correctly." + prompter = ShareGPTPrompterV2(conversation="llama-2") + strat = ShareGPTPromptTokenizingStrategy( + prompter, + self.tokenizer, + False, + 2048, + ) + + def tokenize(conv): + return strat.tokenize_prompt(conv)["input_ids"] + + def decode(ids): + return strat.tokenizer.decode(ids) + + # Multi-turn conversations + multi_turn_conv = { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + } + # fmt: off + mt_ids = tokenize(multi_turn_conv) + assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + + # Single-turn conversations + single_turn_conv = { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + } + + st_ids = tokenize(single_turn_conv) + assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, single-turn + no_sys_conv = { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + } + + ns_ids = tokenize(no_sys_conv) + assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' + assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, multi-turn + no_sys_mt_conv = { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + } + ns_mt_ids = tokenize(no_sys_mt_conv) + assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' + assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + # fmt: on + def test_sharegpt_changes_roles(self): conversation = { "roles": ["USER", "CHARACTER"],