From 1623a5055f532313cd0331e17a1fbe543b98e0ab Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 18 Jan 2024 02:44:15 +0900 Subject: [PATCH 01/15] feat(prompt): support multiple roles for sharegpt --- src/axolotl/prompt_strategies/sharegpt.py | 156 +++++++++++++++++++++- src/axolotl/prompters.py | 59 ++++++++ 2 files changed, 213 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 15bfee8c47..e6a2c19cb3 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -1,10 +1,23 @@ """Module containing the SimpleShareGPTPromptTokenizingStrategy class""" +import copy +import logging from typing import Any, Dict, Optional from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 +from axolotl.prompt_tokenizers import ( + InvalidDataException, + ShareGPTPromptTokenizingStrategy, + parse_tokenized_to_result, + tokenize_prompt_default, +) +from axolotl.prompters import ( + IGNORE_TOKEN_ID, + ShareGPTPrompterV2, + ShareGPTPrompterV2MultiRole, +) + +LOG = logging.getLogger("axolotl") def register_chatml_template(system_message=None): @@ -77,6 +90,28 @@ def load_guanaco(tokenizer, cfg): ) +def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None + ) + field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None + field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None + strategy = MultiRoleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2MultiRole( + conversation=conversation, + role_key_model=field_model, + role_key_human=field_human, + ), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "strict" in ds_cfg: + strategy.strict = ds_cfg["strict"] + + return strategy + + class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): """ basic sharegpt strategy to grab conversations from the sample row @@ -143,3 +178,120 @@ def get_conversation_thread(self, prompt): {"from": role_map[t["role"]], "value": t["content"]} for t in conversations ] return turns + + +class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): + """ + sharegpt strategy for support of multi-role + """ + + def tokenize_prompt(self, prompt): + # Initial values. We will append to these as we go through the conversation. 
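+        # (tokenize_prompt_default() just returns an empty
+        # input_ids/attention_mask/labels dict and a running length of 0;
+        # each turn tokenized below is appended to it via
+        # parse_tokenized_to_result.)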
+ result, current_len = tokenize_prompt_default() + conversation: Conversation = ( + self.prompter._conversation.copy() # pylint: disable=protected-access + ) + user, assistant = conversation.roles + + input_roles = { + "human", + "funcresponse", + "funccaller", + "tool", + "tool_response", + user, + } + output_roles = {"gpt", "tool_caller", assistant} + + # support for custom roles from the dataset, only useful for vicuna style prompts/roles + role_remap = [] + if ( + conversation.name == "vicuna_v1.1" + and "roles" in prompt + and len(prompt["roles"]) >= 2 + ): + role_remap = [ + {"from": conversation.roles[0], "to": prompt["roles"][0]}, + {"from": conversation.roles[1], "to": prompt["roles"][1]}, + ] + + try: + for _, part in enumerate( + self.prompter.build_prompt(self.get_conversation_thread(prompt)) + ): + if not isinstance(part, tuple): + LOG.warning(f"expected tuple, got {part}") + continue + + role, content = part + + # Uses "in" because role contains extra characters + input_turn = any(r in role.lower() for r in input_roles) + output_turn = any(r in role.lower() for r in output_roles) + + if input_turn: + role = ( + role.replace(role_remap[0]["from"], role_remap[0]["to"]) + if role_remap + else role + ) + turn = role + content + # this is still the user query, we should + if not content.strip(): + LOG.warning(f"user turn has empty text: {prompt}") + res = self._tokenize( + turn, + add_eos_token=False, + strip_bos_token=True, + ) + if self.train_on_inputs: + labels = copy.deepcopy(res["input_ids"]) + else: + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + elif output_turn: + role = ( + role.replace(role_remap[1]["from"], role_remap[1]["to"]) + if role_remap + else role + ) + turn = role + content + # this should be the assistant response, should end with an eos token + if not content.strip(): + LOG.warning(f"assistant turn has empty text: {prompt}") + add_eos_token = not ( + conversation.name == "chatml" + and conversation.sep == self.tokenizer.eos_token + ) + res = self._tokenize( + turn, + add_eos_token=add_eos_token, + strip_bos_token=True, + ) + role_res = self._tokenize( + role.rstrip(), + add_eos_token=False, + strip_bos_token=True, + ) + labels = copy.deepcopy(res["input_ids"]) + if not self.train_on_inputs: + # mask out role tokens from the labels + len_role = len(role_res["input_ids"]) + labels[:len_role] = [IGNORE_TOKEN_ID] * min( + len_role, len(labels) + ) + else: + LOG.warning(f"unhandled role: {role}") + continue + + # pylint: disable=duplicate-code + result, current_len = parse_tokenized_to_result( + result, + current_len, + res, + labels, + pad_token_id=self.tokenizer.pad_token_id, + ) + return result + except (KeyError, AssertionError, IndexError) as err: + raise InvalidDataException(str(err)) from err diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 748db1a162..05bf53510e 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -355,6 +355,65 @@ def __init__( ) +CONVERSATION_ROLE_FORMAT = { + "chatml": "<|im_start|>{ROLE}", + "zephyr": "<|{ROLE}|>", +} + + +class ShareGPTPrompterV2MultiRole(ShareGPTPrompterV2): + """ + An multi-role V2 prompter that generates prompts for the ShareGPT that supports multi-role + """ + + def _build_result(self, source): + if len(source) < 2: + # If there isn't a back and forth conversation, ignore it + # also happens on the data splitting leaving empty conversations + raise IndexError( + f"A conversation entry has less than 2 messages 
:\n{source}" + ) + + conv = self._conversation.copy() + + # Add the conversation system prompt if provided, otherwise use the default one + if source[0]["from"] == "system": + conv.set_system_message(source[0]["value"]) + source.pop(0) + + roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]} + + try: + # Apply prompt templates + if source[0]["from"] not in roles: + # Skip the first one if it is not from human + source = source[1:] + except IndexError as err: + # sometimes there is a bing or system chat + raise err + + conv.messages = [] + for _, sentence in enumerate(source): + from_role = sentence["from"] + if from_role in roles: + role = roles[from_role] + else: + if self._conversation.name not in CONVERSATION_ROLE_FORMAT: + raise NotImplementedError( + f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet." + ) + + role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( + ROLE=from_role + ) + + if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): + LOG.warning(f"Roles did not alternate: {sentence}") + conv.append_message(role, sentence["value"]) + + return conv.get_turns() + + class UnsupportedPrompter(Prompter): """ A dummy class for custom prompters From 85ddde2ada7b84678ad473c153745743b13b7b05 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 18 Jan 2024 18:07:52 +0900 Subject: [PATCH 02/15] fix: add handling of empty role back --- src/axolotl/prompt_strategies/sharegpt.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index e6a2c19cb3..dbd46e82bc 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -280,6 +280,17 @@ def tokenize_prompt(self, prompt): labels[:len_role] = [IGNORE_TOKEN_ID] * min( len_role, len(labels) ) + elif role == "": + turn = content + # this is only ever the first part, should include the bos token and the user query + res = self._tokenize( + turn, add_eos_token=False, strip_bos_token=False + ) + if self.train_on_inputs: + labels = copy.deepcopy(res["input_ids"]) + else: + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: LOG.warning(f"unhandled role: {role}") continue From 33bcf57a5e1e0abbae2eefc6d0048471a2df6c9f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 Feb 2024 22:12:53 +0900 Subject: [PATCH 03/15] feat: rebased and allowed more dynamic roles via config --- README.md | 6 +- src/axolotl/prompt_strategies/sharegpt.py | 166 +--------------------- src/axolotl/prompt_tokenizers.py | 34 +++-- src/axolotl/prompters.py | 89 ++++-------- 4 files changed, 61 insertions(+), 234 deletions(-) diff --git a/README.md b/README.md index 3c9f030007..754c017719 100644 --- a/README.md +++ b/README.md @@ -597,9 +597,13 @@ datasets: train_on_split: train # Optional[str] name of dataset split to load from # Optional[str] fastchat conversation type, only used with type: sharegpt - conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py + conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py field_human: # Optional[str]. Human key to use for conversation. field_model: # Optional[str]. Assistant key to use for conversation. 
+  # Add additional keys from your dataset as input or output roles
+  roles:
+    input: # Optional[List[str]]. These will be masked based on train_on_inputs
+    output: # Optional[List[str]].
 
   # Custom user instruction prompt
   - path: repo

diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index dbd46e82bc..bc5f16acb5 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,21 +1,11 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
-import copy
 import logging
 from typing import Any, Dict, Optional
 
 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
 
-from axolotl.prompt_tokenizers import (
-    InvalidDataException,
-    ShareGPTPromptTokenizingStrategy,
-    parse_tokenized_to_result,
-    tokenize_prompt_default,
-)
-from axolotl.prompters import (
-    IGNORE_TOKEN_ID,
-    ShareGPTPrompterV2,
-    ShareGPTPrompterV2MultiRole,
-)
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
 
 LOG = logging.getLogger("axolotl")
 
@@ -40,11 +30,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
     strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
             role_key_human=field_human,
+            roles=roles,
         ),
         tokenizer,
         cfg.train_on_inputs,
@@ -90,28 +82,6 @@ def load_guanaco(tokenizer, cfg):
     )
 
 
-def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
-    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    strategy = MultiRoleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2MultiRole(
-            conversation=conversation,
-            role_key_model=field_model,
-            role_key_human=field_human,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-
-    return strategy
-
-
 class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
     basic sharegpt strategy to grab conversations from the sample row
@@ -178,131 +148,3 @@ def get_conversation_thread(self, prompt):
         {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
     ]
     return turns
-
-
-class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
-    """
-    sharegpt strategy for support of multi-role
-    """
-
-    def tokenize_prompt(self, prompt):
-        # Initial values. We will append to these as we go through the conversation.
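-        # (tokenize_prompt_default() just returns an empty
-        # input_ids/attention_mask/labels dict and a running length of 0;
-        # each turn tokenized below is appended to it via
-        # parse_tokenized_to_result.)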
- result, current_len = tokenize_prompt_default() - conversation: Conversation = ( - self.prompter._conversation.copy() # pylint: disable=protected-access - ) - user, assistant = conversation.roles - - input_roles = { - "human", - "funcresponse", - "funccaller", - "tool", - "tool_response", - user, - } - output_roles = {"gpt", "tool_caller", assistant} - - # support for custom roles from the dataset, only useful for vicuna style prompts/roles - role_remap = [] - if ( - conversation.name == "vicuna_v1.1" - and "roles" in prompt - and len(prompt["roles"]) >= 2 - ): - role_remap = [ - {"from": conversation.roles[0], "to": prompt["roles"][0]}, - {"from": conversation.roles[1], "to": prompt["roles"][1]}, - ] - - try: - for _, part in enumerate( - self.prompter.build_prompt(self.get_conversation_thread(prompt)) - ): - if not isinstance(part, tuple): - LOG.warning(f"expected tuple, got {part}") - continue - - role, content = part - - # Uses "in" because role contains extra characters - input_turn = any(r in role.lower() for r in input_roles) - output_turn = any(r in role.lower() for r in output_roles) - - if input_turn: - role = ( - role.replace(role_remap[0]["from"], role_remap[0]["to"]) - if role_remap - else role - ) - turn = role + content - # this is still the user query, we should - if not content.strip(): - LOG.warning(f"user turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=False, - strip_bos_token=True, - ) - if self.train_on_inputs: - labels = copy.deepcopy(res["input_ids"]) - else: - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif output_turn: - role = ( - role.replace(role_remap[1]["from"], role_remap[1]["to"]) - if role_remap - else role - ) - turn = role + content - # this should be the assistant response, should end with an eos token - if not content.strip(): - LOG.warning(f"assistant turn has empty text: {prompt}") - add_eos_token = not ( - conversation.name == "chatml" - and conversation.sep == self.tokenizer.eos_token - ) - res = self._tokenize( - turn, - add_eos_token=add_eos_token, - strip_bos_token=True, - ) - role_res = self._tokenize( - role.rstrip(), - add_eos_token=False, - strip_bos_token=True, - ) - labels = copy.deepcopy(res["input_ids"]) - if not self.train_on_inputs: - # mask out role tokens from the labels - len_role = len(role_res["input_ids"]) - labels[:len_role] = [IGNORE_TOKEN_ID] * min( - len_role, len(labels) - ) - elif role == "": - turn = content - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - turn, add_eos_token=False, strip_bos_token=False - ) - if self.train_on_inputs: - labels = copy.deepcopy(res["input_ids"]) - else: - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - else: - LOG.warning(f"unhandled role: {role}") - continue - - # pylint: disable=duplicate-code - result, current_len = parse_tokenized_to_result( - result, - current_len, - res, - labels, - pad_token_id=self.tokenizer.pad_token_id, - ) - return result - except (KeyError, AssertionError, IndexError) as err: - raise InvalidDataException(str(err)) from err diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index a5c243f7e6..e2dc756297 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -11,7 +11,7 @@ from axolotl.monkeypatch.fastchat_conversation_turns import ( add_get_turns_to_conversation, ) -from 
axolotl.prompters import IGNORE_TOKEN_ID +from axolotl.prompters import IGNORE_TOKEN_ID, Prompter LOG = logging.getLogger("axolotl") @@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC): def __init__( self, - prompter, + prompter: Prompter, tokenizer, train_on_inputs: bool = False, sequence_len: int = 2048, @@ -340,6 +340,19 @@ def tokenize_prompt(self, prompt): self.prompter._conversation.copy() # pylint: disable=protected-access ) + input_roles = {conversation.roles[0]} + output_roles = {conversation.roles[1]} + + # Add roles from the config + if self.prompter.roles: + if "input" in self.prompter.roles and self.prompter.roles["input"]: + for role in self.prompter.roles["input"]: + input_roles.add(role) + + if "output" in self.prompter.roles and self.prompter.roles["output"]: + for role in self.prompter.roles["output"]: + output_roles.add(role) + # support for custom roles from the dataset, only useful for vicuna style prompts/roles role_remap = [] if ( @@ -360,11 +373,19 @@ def tokenize_prompt(self, prompt): LOG.warning(f"expected tuple, got {part}") continue - user, assistant = conversation.roles role, content = part # Uses "in" because role contains extra characters - if user in role: + input_turn = any(r in role.lower() for r in input_roles) + output_turn = any(r in role.lower() for r in output_roles) + empty_role = role.strip() == "" + + if not any([input_turn, output_turn, empty_role]): + LOG.warning(f"unhandled role: {role}") + continue + + # Uses "in" because role contains extra characters + if input_turn: role = ( role.replace(role_remap[0]["from"], role_remap[0]["to"]) if role_remap @@ -384,7 +405,7 @@ def tokenize_prompt(self, prompt): else: # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif assistant in role: + elif output_turn: role = ( role.replace(role_remap[1]["from"], role_remap[1]["to"]) if role_remap @@ -426,9 +447,6 @@ def tokenize_prompt(self, prompt): else: # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - else: - LOG.warning(f"unhandled role: {role}") - continue # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 05bf53510e..cba057ca87 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -259,6 +259,11 @@ def __repr__(self) -> str: "Role did not alternate between turns (gpt and human). Please check your data." 
 )
 
+CONVERSATION_ROLE_FORMAT = {
+    "chatml": "<|im_start|>{ROLE}",
+    "zephyr": "<|{ROLE}|>",
+}
+
 
 class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
     """
@@ -274,6 +279,7 @@ def __init__(
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         if conversation:
             if isinstance(conversation, Conversation):
@@ -287,6 +293,8 @@ def __init__(
         if role_key_model:
             self.role_key_model = role_key_model
 
+        self.roles = roles
+
     def _build_result(self, source):
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
@@ -315,11 +323,23 @@ def _build_result(self, source):
 
         conv.messages = []
         for _, sentence in enumerate(source):
-            role = roles[sentence["from"]]
-            if len(conv.messages) > 0 and (
-                (role == conv.messages[-1][0]) or (role not in conv.roles)
-            ):
+            from_role = sentence["from"]
+            if from_role in roles:
+                role = roles[from_role]
+            else:
+                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
+                    raise NotImplementedError(
+                        f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
+                        "Please help us by creating an Issue to add support for this role."
+                    )
+
+                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
+                    ROLE=from_role
+                )
+
+            if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
+
             conv.append_message(role, sentence["value"])
 
         return conv.get_turns()
@@ -347,73 +367,16 @@ def __init__(
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         super().__init__(
             conversation=conversation,
             role_key_human=role_key_human,
             role_key_model=role_key_model,
+            roles=roles,
         )
 
 
-CONVERSATION_ROLE_FORMAT = {
-    "chatml": "<|im_start|>{ROLE}",
-    "zephyr": "<|{ROLE}|>",
-}
-
-
-class ShareGPTPrompterV2MultiRole(ShareGPTPrompterV2):
-    """
-    An multi-role V2 prompter that generates prompts for the ShareGPT that supports multi-role
-    """
-
-    def _build_result(self, source):
-        if len(source) < 2:
-            # If there isn't a back and forth conversation, ignore it
-            # also happens on the data splitting leaving empty conversations
-            raise IndexError(
-                f"A conversation entry has less than 2 messages :\n{source}"
-            )
-
-        conv = self._conversation.copy()
-
-        # Add the conversation system prompt if provided, otherwise use the default one
-        if source[0]["from"] == "system":
-            conv.set_system_message(source[0]["value"])
-            source.pop(0)
-
-        roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
-
-        try:
-            # Apply prompt templates
-            if source[0]["from"] not in roles:
-                # Skip the first one if it is not from human
-                source = source[1:]
-        except IndexError as err:
-            # sometimes there is a bing or system chat
-            raise err
-
-        conv.messages = []
-        for _, sentence in enumerate(source):
-            from_role = sentence["from"]
-            if from_role in roles:
-                role = roles[from_role]
-            else:
-                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
-                    raise NotImplementedError(
-                        f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
-                    )
-
-                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
-                    ROLE=from_role
-                )
-
-            if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
-                LOG.warning(f"Roles did not alternate: {sentence}")
-            conv.append_message(role, sentence["value"])
-
-        return conv.get_turns()
-
-
 class UnsupportedPrompter(Prompter):
     """
     A dummy class for custom prompters

From b54869c7ff5d1527cd2aec7c78ab30385faf8d9e Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 22 Feb 2024 22:31:15 +0900
Subject: [PATCH 04/15] fix: variable

---
 src/axolotl/prompt_tokenizers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index e2dc756297..7fa41ca43a 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -436,7 +436,7 @@ def tokenize_prompt(self, prompt):
                 labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                     len_role, len(labels)
                 )
-            elif role == "":
+            elif empty_role:
                 turn = content
                 # this is only ever the first part, should include the bos token and the user query
                 res = self._tokenize(

From 8751ffb08db060247f7871caed40e70cdeb9c2ab Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 22 Feb 2024 22:31:26 +0900
Subject: [PATCH 05/15] chore: update message

---
 src/axolotl/prompters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index cba057ca87..5d419389ed 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -330,7 +330,7 @@ def _build_result(self, source):
             if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
                 raise NotImplementedError(
                     f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
-                    "Please help us by creating an Issue to add support for this role."
+                    "Please help us by creating an Issue to add support for this conversation type."
) role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( From 24591e88ef5ffd97a15a0568971008b508ff4bcd Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 23 Feb 2024 17:08:02 +0900 Subject: [PATCH 06/15] feat: add vicuna format --- src/axolotl/prompt_tokenizers.py | 4 ++-- src/axolotl/prompters.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 7fa41ca43a..b0f479f044 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -376,8 +376,8 @@ def tokenize_prompt(self, prompt): role, content = part # Uses "in" because role contains extra characters - input_turn = any(r in role.lower() for r in input_roles) - output_turn = any(r in role.lower() for r in output_roles) + input_turn = any(r.lower() in role.lower() for r in input_roles) + output_turn = any(r.lower() in role.lower() for r in output_roles) empty_role = role.strip() == "" if not any([input_turn, output_turn, empty_role]): diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 5d419389ed..6883ffec85 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -262,6 +262,7 @@ def __repr__(self) -> str: CONVERSATION_ROLE_FORMAT = { "chatml": "<|im_start|>{ROLE}", "zephyr": "<|{ROLE}|>", + "vicuna_v1.1": "{ROLE}", } From 9e171a98adb6e3068d371d6b3b5fb031974ad11c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 23 Feb 2024 17:49:58 +0900 Subject: [PATCH 07/15] fix: JSON serializable error --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a5986fa4ff..0ccb45be4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d71 tokenizers==0.15.0 bitsandbytes>=0.41.1 accelerate==0.26.1 -deepspeed>=0.13.1 +deepspeed==0.13.1 addict fire PyYAML>=6.0 From 2b9d66c56bf9aa781700844db271a408cbdc4818 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:34:53 +0900 Subject: [PATCH 08/15] fix: typing --- src/axolotl/prompters.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 49f455456d..e670448298 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -273,12 +273,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods role_key_human = "human" role_key_model = "gpt" - # Optional, only used for tool usage datasets. 
- role_key_tool = None - + role_key_tool: Optional[str] = None # Optional, role input/output mapping - roles = None + roles: Optional[dict] = None def __init__( self, @@ -302,6 +300,7 @@ def __init__( self.role_key_model = role_key_model if role_key_tool: self.role_key_tool = role_key_tool + print(roles) if roles: self.roles = roles From ad70d344f3e22e4c57e1aa11178fb41e9017a58f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:35:20 +0900 Subject: [PATCH 09/15] fix: don't remap for unknown keys --- src/axolotl/prompt_strategies/sharegpt.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 1cfa078c4d..0ff022eee1 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -1,4 +1,5 @@ """Module containing the SimpleShareGPTPromptTokenizingStrategy class""" + import logging from typing import Any, Dict, Optional @@ -45,6 +46,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None + print("roles", roles) strategy = SimpleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2( conversation=conversation, @@ -146,7 +148,12 @@ def get_conversation_thread(self, prompt): "system": "system", } turns = [ - {"from": role_map[t[role_key]], "value": t[value_key]} + { + "from": ( + role_map[t[role_key]] if t[role_key] in role_map else t[role_key] + ), + "value": t[value_key], + } for t in conversations ] return turns From 303163299965a11a00ca61aaae0e157bedaf8e55 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:35:30 +0900 Subject: [PATCH 10/15] fix: add roles to pydantic --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index d7e7b24de4..2bf544fd27 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -96,6 +96,8 @@ class SFTDataset(BaseModel): field_human: Optional[str] = None field_model: Optional[str] = None + roles: Optional[Dict[str, List[str]]] = None + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" From c2738b36102ec2f7e9a622666e5fb062bf5ea843 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:36:01 +0900 Subject: [PATCH 11/15] feat: add test --- tests/prompt_strategies/test_sharegpt.py | 57 ++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index c9290b220a..77a26a4bed 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset(): ) +@pytest.fixture(name="multi_role_dataset") +def fixture_multi_role_dataset(): + return Dataset.from_list( + [ + { + "conversations": [ + { + "from": "system", + "value": "use get_weather(city) to get the weather for a city", + }, + { + "from": "human", + "value": "hello, what's the weather in New York?", + }, + { + "from": "gpt", + "value": "let me get that for you", + }, + { + "from": "tool", + "value": "get_weather(New 
York)", + }, + { + "from": "gpt", + "value": "the weather in New York is 70 degrees and sunny", + }, + ] + } + ] + ) + + @pytest.fixture(name="tokenizer") def fixture_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") @@ -196,3 +228,28 @@ def test_chatml_glaive(self, glaive_dataset, tokenizer): 32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt ] # fmt: on + + def test_multi_role_dataset(self, multi_role_dataset, tokenizer): + strategy = SimpleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2(roles={"input": ["tool"]}), + tokenizer, + False, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, multi_role_dataset, process_count=1 + ) + + labels = dataset_wrapper[0]["labels"] + print(labels) + # fmt: off + assert labels == [ + -100, # bos + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human + -100, -100, -100, -100, 1346, 528, 625, 369, 354, 368, 2, 32000, # gpt + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool + -100, -100, -100, -100, 272, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 2, 32000, # gpt + ] + # fmt: on From 320312f3619221700e4aa84c190e91e78f64aeee Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:45:38 +0900 Subject: [PATCH 12/15] chore: remove leftover print --- src/axolotl/prompt_strategies/sharegpt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 0ff022eee1..55bdd37b4f 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -46,7 +46,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None - print("roles", roles) strategy = SimpleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2( conversation=conversation, From 4b86dd8243e2b13d7ec4f285519978930d29e6d0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:46:22 +0900 Subject: [PATCH 13/15] chore: remove leftover comment --- src/axolotl/prompt_tokenizers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index f0b94d033d..bb13cf76dd 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -388,7 +388,6 @@ def tokenize_prompt(self, prompt): LOG.warning(f"unhandled role: {role}") continue - # Uses "in" because role contains extra characters if input_turn: role = ( role.replace(role_remap[0]["from"], role_remap[0]["to"]) From f00f63bb36744dfe2fca8456cb4965fb73e3dfe4 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 12 Mar 2024 13:47:13 +0900 Subject: [PATCH 14/15] chore: remove print --- src/axolotl/prompters.py | 1 - tests/prompt_strategies/test_sharegpt.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index e670448298..2b6b4f8577 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ 
-300,7 +300,6 @@ def __init__( self.role_key_model = role_key_model if role_key_tool: self.role_key_tool = role_key_tool - print(roles) if roles: self.roles = roles diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index 77a26a4bed..8ca9b3e0e6 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -242,7 +242,6 @@ def test_multi_role_dataset(self, multi_role_dataset, tokenizer): ) labels = dataset_wrapper[0]["labels"] - print(labels) # fmt: off assert labels == [ -100, # bos From cc545db4b2775c03657da05f41cd4d2195dd2e0e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 Mar 2024 09:10:05 +0900 Subject: [PATCH 15/15] fix: update test to use chatml --- tests/prompt_strategies/test_sharegpt.py | 26 +++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index 8ca9b3e0e6..19d63eac83 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -231,7 +231,7 @@ def test_chatml_glaive(self, glaive_dataset, tokenizer): def test_multi_role_dataset(self, multi_role_dataset, tokenizer): strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(roles={"input": ["tool"]}), + ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}), tokenizer, False, # train_on_inputs 2048, # sequence_len @@ -241,14 +241,26 @@ def test_multi_role_dataset(self, multi_role_dataset, tokenizer): strategy, multi_role_dataset, process_count=1 ) + input_ids = dataset_wrapper[0]["input_ids"] + # fmt: off + assert input_ids == [ + 1, # bos + 32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system + 32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human + 32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt + 32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool + 32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt + ] + # fmt: on + labels = dataset_wrapper[0]["labels"] # fmt: off assert labels == [ - -100, # bos - -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system - -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, -100, -100, 1346, 528, 625, 369, 354, 368, 2, 32000, # gpt - -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool - -100, -100, -100, -100, 272, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 2, 32000, # gpt + -100, # bos + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human + -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool + -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt ] # fmt: on
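
A minimal sketch of how the new `roles` option is exercised end to end, mirroring the test added above. The import paths and the tokenizer special-token setup are assumptions (the test module's fixtures are only partially shown in these hunks), and the dataset rows are illustrative:

from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
    SimpleShareGPTPromptTokenizingStrategy,
    register_chatml_template,
)
from axolotl.prompters import ShareGPTPrompterV2

register_chatml_template()  # registers the "chatml" fastchat template used below

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# Assumed setup: the test fixture presumably registers the chatml markers as
# special tokens; without them the turns still tokenize, just differently.
tokenizer.add_special_tokens(
    {"eos_token": "<|im_end|>", "additional_special_tokens": ["<|im_start|>"]}
)

# "tool" is declared an input role, so its turns (like "human" turns) are
# label-masked with -100 because train_on_inputs is False.
strategy = SimpleShareGPTPromptTokenizingStrategy(
    ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
    tokenizer,
    False,  # train_on_inputs
    2048,  # sequence_len
)

ds = Dataset.from_list(
    [
        {
            "conversations": [
                {"from": "human", "value": "what's the weather in New York?"},
                {"from": "gpt", "value": "let me get that for you"},
                {"from": "tool", "value": "get_weather(New York)"},
                {"from": "gpt", "value": "70 degrees and sunny"},
            ]
        }
    ]
)
tokenized = TokenizedPromptDataset(strategy, ds, process_count=1)
print(tokenized[0]["labels"])  # only the gpt turns keep real label ids

The same mapping in YAML form is the `roles:` block documented in the README hunk above.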