From 885f60335c8c82320051b8944cda1dfdb585d662 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 18 Jan 2024 02:44:15 +0900 Subject: [PATCH] feat(prompt): support multiple roles for sharegpt --- src/axolotl/prompt_strategies/sharegpt.py | 156 +++++++++++++++++++++- src/axolotl/prompters.py | 59 ++++++++ 2 files changed, 213 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index c026889682..aaa199661d 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -1,10 +1,23 @@ """Module containing the SimpleShareGPTPromptTokenizingStrategy class""" +import copy +import logging from typing import Any, Dict, Optional from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 +from axolotl.prompt_tokenizers import ( + InvalidDataException, + ShareGPTPromptTokenizingStrategy, + parse_tokenized_to_result, + tokenize_prompt_default, +) +from axolotl.prompters import ( + IGNORE_TOKEN_ID, + ShareGPTPrompterV2, + ShareGPTPrompterV2MultiRole, +) + +LOG = logging.getLogger("axolotl") register_conv_template( Conversation( @@ -74,6 +87,28 @@ def load_guanaco(tokenizer, cfg): ) +def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None + ) + field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None + field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None + strategy = MultiRoleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2MultiRole( + conversation=conversation, + role_key_model=field_model, + role_key_human=field_human, + ), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "strict" in ds_cfg: + strategy.strict = ds_cfg["strict"] + + return strategy + + class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): """ basic sharegpt strategy to grab conversations from the sample row @@ -140,3 +175,120 @@ def get_conversation_thread(self, prompt): {"from": role_map[t["role"]], "value": t["content"]} for t in conversations ] return turns + + +class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): + """ + sharegpt strategy for support of multi-role + """ + + def tokenize_prompt(self, prompt): + # Initial values. We will append to these as we go through the conversation. + result, current_len = tokenize_prompt_default() + conversation: Conversation = ( + self.prompter._conversation.copy() # pylint: disable=protected-access + ) + user, assistant = conversation.roles + + input_roles = { + "human", + "funcresponse", + "funccaller", + "tool", + "tool_response", + user, + } + output_roles = {"gpt", "tool_caller", assistant} + + # support for custom roles from the dataset, only useful for vicuna style prompts/roles + role_remap = [] + if ( + conversation.name == "vicuna_v1.1" + and "roles" in prompt + and len(prompt["roles"]) >= 2 + ): + role_remap = [ + {"from": conversation.roles[0], "to": prompt["roles"][0]}, + {"from": conversation.roles[1], "to": prompt["roles"][1]}, + ] + + try: + for _, part in enumerate( + self.prompter.build_prompt(self.get_conversation_thread(prompt)) + ): + if not isinstance(part, tuple): + LOG.warning(f"expected tuple, got {part}") + continue + + role, content = part + + # Uses "in" because role contains extra characters + input_turn = any(r in role.lower() for r in input_roles) + output_turn = any(r in role.lower() for r in output_roles) + + if input_turn: + role = ( + role.replace(role_remap[0]["from"], role_remap[0]["to"]) + if role_remap + else role + ) + turn = role + content + # this is still the user query, we should + if not content.strip(): + LOG.warning(f"user turn has empty text: {prompt}") + res = self._tokenize( + turn, + add_eos_token=False, + strip_bos_token=True, + ) + if self.train_on_inputs: + labels = copy.deepcopy(res["input_ids"]) + else: + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + elif output_turn: + role = ( + role.replace(role_remap[1]["from"], role_remap[1]["to"]) + if role_remap + else role + ) + turn = role + content + # this should be the assistant response, should end with an eos token + if not content.strip(): + LOG.warning(f"assistant turn has empty text: {prompt}") + add_eos_token = not ( + conversation.name == "chatml" + and conversation.sep == self.tokenizer.eos_token + ) + res = self._tokenize( + turn, + add_eos_token=add_eos_token, + strip_bos_token=True, + ) + role_res = self._tokenize( + role.rstrip(), + add_eos_token=False, + strip_bos_token=True, + ) + labels = copy.deepcopy(res["input_ids"]) + if not self.train_on_inputs: + # mask out role tokens from the labels + len_role = len(role_res["input_ids"]) + labels[:len_role] = [IGNORE_TOKEN_ID] * min( + len_role, len(labels) + ) + else: + LOG.warning(f"unhandled role: {role}") + continue + + # pylint: disable=duplicate-code + result, current_len = parse_tokenized_to_result( + result, + current_len, + res, + labels, + pad_token_id=self.tokenizer.pad_token_id, + ) + return result + except (KeyError, AssertionError, IndexError) as err: + raise InvalidDataException(str(err)) from err diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 73966def3f..34bfc1d7ae 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -355,6 +355,65 @@ def __init__( ) +CONVERSATION_ROLE_FORMAT = { + "chatml": "<|im_start|>{ROLE}", + "zephyr": "<|{ROLE}|>", +} + + +class ShareGPTPrompterV2MultiRole(ShareGPTPrompterV2): + """ + An multi-role V2 prompter that generates prompts for the ShareGPT that supports multi-role + """ + + def _build_result(self, source): + if len(source) < 2: + # If there isn't a back and forth conversation, ignore it + # also happens on the data splitting leaving empty conversations + raise IndexError( + f"A conversation entry has less than 2 messages :\n{source}" + ) + + conv = self._conversation.copy() + + # Add the conversation system prompt if provided, otherwise use the default one + if source[0]["from"] == "system": + conv.set_system_message(source[0]["value"]) + source.pop(0) + + roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]} + + try: + # Apply prompt templates + if source[0]["from"] not in roles: + # Skip the first one if it is not from human + source = source[1:] + except IndexError as err: + # sometimes there is a bing or system chat + raise err + + conv.messages = [] + for _, sentence in enumerate(source): + from_role = sentence["from"] + if from_role in roles: + role = roles[from_role] + else: + if self._conversation.name not in CONVERSATION_ROLE_FORMAT: + raise NotImplementedError( + f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet." + ) + + role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( + ROLE=from_role + ) + + if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): + LOG.warning(f"Roles did not alternate: {sentence}") + conv.append_message(role, sentence["value"]) + + return conv.get_turns() + + class UnsupportedPrompter(Prompter): """ A dummy class for custom prompters