From 0e03717dea1aa36621b5db883cfb625ec7129037 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 19 Mar 2024 20:51:49 +0900 Subject: [PATCH] Feat: Add sharegpt multirole (#1137) * feat(prompt): support multiple roles for sharegpt * fix: add handling of empty role back * feat: rebased and allowed more dynamic roles via config * fix: variable * chore: update message * feat: add vicuna format * fix: JSON serializable error * fix: typing * fix: don't remap for unknown keys * fix: add roles to pydantic * feat: add test * chore: remove leftover print * chore: remove leftover comment * chore: remove print * fix: update test to use chatml --- README.md | 6 +- src/axolotl/prompt_strategies/sharegpt.py | 12 +++- src/axolotl/prompt_tokenizers.py | 49 +++++++------ src/axolotl/prompters.py | 35 ++++++++-- .../config/models/input/v0_4_1/__init__.py | 2 + tests/prompt_strategies/test_sharegpt.py | 68 +++++++++++++++++++ 6 files changed, 146 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 41521f081c..95e0b57ad9 100644 --- a/README.md +++ b/README.md @@ -651,9 +651,13 @@ datasets: train_on_split: train # Optional[str] name of dataset split to load from # Optional[str] fastchat conversation type, only used with type: sharegpt - conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py + conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py field_human: # Optional[str]. Human key to use for conversation. field_model: # Optional[str]. Assistant key to use for conversation. + # Add additional keys from your dataset as input or output roles + roles: + input: # Optional[List[str]]. These will be masked based on train_on_input + output: # Optional[List[str]]. # Custom user instruction prompt - path: repo diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 7a7f61a8ee..55bdd37b4f 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -1,5 +1,6 @@ """Module containing the SimpleShareGPTPromptTokenizingStrategy class""" +import logging from typing import Any, Dict, Optional from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template @@ -11,6 +12,8 @@ merge_consecutive_messages, ) +LOG = logging.getLogger("axolotl") + def register_chatml_template(system_message=None): system_message = system_message or "You are a helpful assistant." @@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ) field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None + roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None strategy = SimpleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2( conversation=conversation, role_key_model=field_model, role_key_human=field_human, + roles=roles, ), tokenizer, cfg.train_on_inputs, @@ -142,7 +147,12 @@ def get_conversation_thread(self, prompt): "system": "system", } turns = [ - {"from": role_map[t[role_key]], "value": t[value_key]} + { + "from": ( + role_map[t[role_key]] if t[role_key] in role_map else t[role_key] + ), + "value": t[value_key], + } for t in conversations ] return turns diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 7e62a0cd4c..bb13cf76dd 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -11,7 +11,7 @@ from axolotl.monkeypatch.fastchat_conversation_turns import ( add_get_turns_to_conversation, ) -from axolotl.prompters import IGNORE_TOKEN_ID +from axolotl.prompters import IGNORE_TOKEN_ID, Prompter LOG = logging.getLogger("axolotl") @@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC): def __init__( self, - prompter, + prompter: Prompter, tokenizer, train_on_inputs: bool = False, sequence_len: int = 2048, @@ -340,6 +340,23 @@ def tokenize_prompt(self, prompt): self.prompter._conversation.copy() # pylint: disable=protected-access ) + input_roles = {conversation.roles[0]} + output_roles = {conversation.roles[1]} + + if len(conversation.roles) == 3: + tool_role_label = conversation.roles[2] + input_roles.add(tool_role_label) + + # Add roles from the config + if self.prompter.roles: + if "input" in self.prompter.roles and self.prompter.roles["input"]: + for role in self.prompter.roles["input"]: + input_roles.add(role) + + if "output" in self.prompter.roles and self.prompter.roles["output"]: + for role in self.prompter.roles["output"]: + output_roles.add(role) + # support for custom roles from the dataset, only useful for vicuna style prompts/roles role_remap = [] if ( @@ -360,19 +377,18 @@ def tokenize_prompt(self, prompt): LOG.warning(f"expected tuple, got {part}") continue - tool_role_label = None - if len(conversation.roles) == 3: - ( - user_role_label, - assistant_role_label, - tool_role_label, - ) = conversation.roles - else: - user_role_label, assistant_role_label = conversation.roles role, content = part # Uses "in" because role contains extra characters - if user_role_label in role: + input_turn = any(r.lower() in role.lower() for r in input_roles) + output_turn = any(r.lower() in role.lower() for r in output_roles) + empty_role = role.strip() == "" + + if not any([input_turn, output_turn, empty_role]): + LOG.warning(f"unhandled role: {role}") + continue + + if input_turn: role = ( role.replace(role_remap[0]["from"], role_remap[0]["to"]) if role_remap @@ -392,7 +408,7 @@ def tokenize_prompt(self, prompt): else: # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif assistant_role_label in role: + elif output_turn: role = ( role.replace(role_remap[1]["from"], role_remap[1]["to"]) if role_remap @@ -423,7 +439,7 @@ def tokenize_prompt(self, prompt): labels[:len_role] = [IGNORE_TOKEN_ID] * min( len_role, len(labels) ) - elif role == "": + elif empty_role: turn = content # this is only ever the first part, should include the bos token and the user query res = self._tokenize( @@ -434,11 +450,6 @@ def tokenize_prompt(self, prompt): else: # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif tool_role_label and tool_role_label in role: - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - else: - LOG.warning(f"unhandled role: {role}") - continue # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index fa181f916d..2b6b4f8577 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -259,6 +259,12 @@ def __repr__(self) -> str: "Role did not alternate between turns (gpt and human). Please check your data." ) +CONVERSATION_ROLE_FORMAT = { + "chatml": "<|im_start|>{ROLE}", + "zephyr": "<|{ROLE}|>", + "vicuna_v1.1": "{ROLE}", +} + class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods """ @@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods role_key_human = "human" role_key_model = "gpt" # Optional, only used for tool usage datasets. - role_key_tool = None + role_key_tool: Optional[str] = None + # Optional, role input/output mapping + roles: Optional[dict] = None def __init__( self, @@ -277,6 +285,7 @@ def __init__( role_key_human: Optional[str] = None, role_key_model: Optional[str] = None, role_key_tool: Optional[str] = None, + roles: Optional[dict] = None, ): if conversation: if isinstance(conversation, Conversation): @@ -291,6 +300,8 @@ def __init__( self.role_key_model = role_key_model if role_key_tool: self.role_key_tool = role_key_tool + if roles: + self.roles = roles def _build_result(self, source): if len(source) < 2: @@ -322,11 +333,23 @@ def _build_result(self, source): conv.messages = [] for _, sentence in enumerate(source): - role = roles[sentence["from"]] - if len(conv.messages) > 0 and ( - (role == conv.messages[-1][0]) or (role not in conv.roles) - ): + from_role = sentence["from"] + if from_role in roles: + role = roles[from_role] + else: + if self._conversation.name not in CONVERSATION_ROLE_FORMAT: + raise NotImplementedError( + f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet." + "Please help us by creating an Issue to add support for this conversation type." + ) + + role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( + ROLE=from_role + ) + + if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") + conv.append_message(role, sentence["value"]) return conv.get_turns() @@ -354,11 +377,13 @@ def __init__( conversation: Optional[Union[str, Conversation]] = None, role_key_human: Optional[str] = None, role_key_model: Optional[str] = None, + roles: Optional[dict] = None, ): super().__init__( conversation=conversation, role_key_human=role_key_human, role_key_model=role_key_model, + roles=roles, ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index b1c395bcc8..a68104bc4b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -96,6 +96,8 @@ class SFTDataset(BaseModel): field_human: Optional[str] = None field_model: Optional[str] = None + roles: Optional[Dict[str, List[str]]] = None + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index c9290b220a..19d63eac83 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset(): ) +@pytest.fixture(name="multi_role_dataset") +def fixture_multi_role_dataset(): + return Dataset.from_list( + [ + { + "conversations": [ + { + "from": "system", + "value": "use get_weather(city) to get the weather for a city", + }, + { + "from": "human", + "value": "hello, what's the weather in New York?", + }, + { + "from": "gpt", + "value": "let me get that for you", + }, + { + "from": "tool", + "value": "get_weather(New York)", + }, + { + "from": "gpt", + "value": "the weather in New York is 70 degrees and sunny", + }, + ] + } + ] + ) + + @pytest.fixture(name="tokenizer") def fixture_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") @@ -196,3 +228,39 @@ def test_chatml_glaive(self, glaive_dataset, tokenizer): 32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt ] # fmt: on + + def test_multi_role_dataset(self, multi_role_dataset, tokenizer): + strategy = SimpleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}), + tokenizer, + False, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, multi_role_dataset, process_count=1 + ) + + input_ids = dataset_wrapper[0]["input_ids"] + # fmt: off + assert input_ids == [ + 1, # bos + 32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system + 32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human + 32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt + 32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool + 32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt + ] + # fmt: on + + labels = dataset_wrapper[0]["labels"] + # fmt: off + assert labels == [ + -100, # bos + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human + -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool + -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt + ] + # fmt: on