Skip to content

Commit

Permalink
feat(prompt): support multiple roles for sharegpt
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Jan 17, 2024
1 parent f2d48d2 commit 885f603
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 2 deletions.
156 changes: 154 additions & 2 deletions src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import copy
import logging
from typing import Any, Dict, Optional

from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.prompt_tokenizers import (
InvalidDataException,
ShareGPTPromptTokenizingStrategy,
parse_tokenized_to_result,
tokenize_prompt_default,
)
from axolotl.prompters import (
IGNORE_TOKEN_ID,
ShareGPTPrompterV2,
ShareGPTPrompterV2MultiRole,
)

LOG = logging.getLogger("axolotl")

register_conv_template(
Conversation(
Expand Down Expand Up @@ -74,6 +87,28 @@ def load_guanaco(tokenizer, cfg):
)


def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a MultiRoleShareGPTPromptTokenizingStrategy from the dataset config.

    :param tokenizer: tokenizer used to encode each conversation turn
    :param cfg: global training config; reads ``train_on_inputs`` and ``sequence_len``
    :param ds_cfg: optional per-dataset config; recognized keys are
        ``conversation``, ``field_human``, ``field_model`` and ``strict``
    :return: a configured MultiRoleShareGPTPromptTokenizingStrategy
    """
    # Missing keys fall back to None so the prompter applies its own defaults.
    opts = ds_cfg or {}
    strategy = MultiRoleShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2MultiRole(
            conversation=opts.get("conversation"),
            role_key_model=opts.get("field_model"),
            role_key_human=opts.get("field_human"),
        ),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    # Only override strictness when explicitly configured for this dataset.
    if "strict" in opts:
        strategy.strict = opts["strict"]

    return strategy


class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
Expand Down Expand Up @@ -140,3 +175,120 @@ def get_conversation_thread(self, prompt):
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns


class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy for support of multi-role conversations: besides the
    basic human/gpt pair, turns may come from tool/function callers and
    responders, which are bucketed into input (masked unless train_on_inputs)
    and output (trained-on) turns.
    """

    def tokenize_prompt(self, prompt):
        """
        Tokenize one conversation into concatenated input_ids/labels.

        Input turns are fully label-masked unless ``self.train_on_inputs``;
        output turns keep their labels except for the role prefix, which is
        always masked.

        :param prompt: one dataset row (a sharegpt-style conversation)
        :raises InvalidDataException: when the row is malformed
        """
        # Initial values. We will append to these as we go through the conversation.
        result, current_len = tokenize_prompt_default()
        conversation: Conversation = (
            self.prompter._conversation.copy()  # pylint: disable=protected-access
        )
        user, assistant = conversation.roles

        # Substrings identifying a turn as model input vs. model output.
        input_roles = {
            "human",
            "funcresponse",
            "funccaller",
            "tool",
            "tool_response",
            user,
        }
        output_roles = {"gpt", "tool_caller", assistant}

        # support for custom roles from the dataset, only useful for vicuna style prompts/roles
        role_remap = []
        if (
            conversation.name == "vicuna_v1.1"
            and "roles" in prompt
            and len(prompt["roles"]) >= 2
        ):
            role_remap = [
                {"from": conversation.roles[0], "to": prompt["roles"][0]},
                {"from": conversation.roles[1], "to": prompt["roles"][1]},
            ]

        try:
            for part in self.prompter.build_prompt(
                self.get_conversation_thread(prompt)
            ):
                if not isinstance(part, tuple):
                    LOG.warning(f"expected tuple, got {part}")
                    continue

                role, content = part

                # Uses "in" because role contains extra characters
                # (template separators/markup around the bare role name)
                input_turn = any(r in role.lower() for r in input_roles)
                output_turn = any(r in role.lower() for r in output_roles)

                # BUG FIX: test output BEFORE input. "tool" (an input role) is
                # a substring of "tool_caller" (an output role), so an
                # input-first check misclassified tool_caller turns as inputs
                # and masked their labels entirely.
                if output_turn:
                    role = (
                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
                        if role_remap
                        else role
                    )
                    turn = role + content
                    # this should be the assistant response, should end with an eos token
                    if not content.strip():
                        LOG.warning(f"assistant turn has empty text: {prompt}")
                    # chatml closes each turn with its own separator, which may
                    # already be the eos token; avoid appending a duplicate eos
                    add_eos_token = not (
                        conversation.name == "chatml"
                        and conversation.sep == self.tokenizer.eos_token
                    )
                    res = self._tokenize(
                        turn,
                        add_eos_token=add_eos_token,
                        strip_bos_token=True,
                    )
                    role_res = self._tokenize(
                        role.rstrip(),
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    labels = copy.deepcopy(res["input_ids"])
                    if not self.train_on_inputs:
                        # mask out role tokens from the labels
                        len_role = len(role_res["input_ids"])
                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                            len_role, len(labels)
                        )
                elif input_turn:
                    role = (
                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
                        if role_remap
                        else role
                    )
                    turn = role + content
                    # this is still the user/input side of the conversation
                    if not content.strip():
                        LOG.warning(f"user turn has empty text: {prompt}")
                    res = self._tokenize(
                        turn,
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    if self.train_on_inputs:
                        labels = copy.deepcopy(res["input_ids"])
                    else:
                        # everything from this is masked out from the labels
                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                else:
                    LOG.warning(f"unhandled role: {role}")
                    continue

                # pylint: disable=duplicate-code
                result, current_len = parse_tokenized_to_result(
                    result,
                    current_len,
                    res,
                    labels,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            return result
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err
59 changes: 59 additions & 0 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,65 @@ def __init__(
)


# Fallback role templates, keyed by fastchat conversation-template name.
# Used to format dataset roles that are not part of the template's default
# human/model role pair (e.g. tool/function roles); {ROLE} is substituted
# with the raw "from" value of the message.
CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
}


class ShareGPTPrompterV2MultiRole(ShareGPTPrompterV2):
    """
    A multi-role V2 prompter that generates prompts for ShareGPT data with
    support for roles beyond the default human/model pair.
    """

    def _build_result(self, source):
        """
        Convert a list of ``{"from": ..., "value": ...}`` messages into
        fastchat conversation turns, remapping unknown roles through
        CONVERSATION_ROLE_FORMAT when the template supports it.

        :param source: message dicts; NOTE: mutated in place when a leading
            system message is consumed (pop(0))
        :return: ``conv.get_turns()`` for the assembled conversation
        :raises IndexError: when the conversation has fewer than 2 messages
        :raises NotImplementedError: for an unknown role on a template with
            no CONVERSATION_ROLE_FORMAT entry
        """
        if len(source) < 2:
            # If there isn't a back and forth conversation, ignore it
            # also happens on the data splitting leaving empty conversations
            raise IndexError(
                f"A conversation entry has less than 2 messages :\n{source}"
            )

        conv = self._conversation.copy()

        # Add the conversation system prompt if provided, otherwise use the default one
        if source[0]["from"] == "system":
            conv.set_system_message(source[0]["value"])
            source.pop(0)

        roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}

        # Skip a leading message that is not from a known role (sometimes there
        # is a bing or system chat); source is non-empty here since at most one
        # message was popped from a list of >= 2.
        if source[0]["from"] not in roles:
            source = source[1:]

        conv.messages = []
        for sentence in source:
            from_role = sentence["from"]
            if from_role in roles:
                role = roles[from_role]
            else:
                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
                    # BUG FIX: message previously interpolated `role`, which is
                    # unassigned at this point in the iteration (UnboundLocalError
                    # on the first unknown role); report `from_role` instead.
                    raise NotImplementedError(
                        f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
                    )

                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
                    ROLE=from_role
                )

            # Warn (but keep going) when two consecutive turns share a role.
            if len(conv.messages) > 0 and role == conv.messages[-1][0]:
                LOG.warning(f"Roles did not alternate: {sentence}")
            conv.append_message(role, sentence["value"])

        return conv.get_turns()


class UnsupportedPrompter(Prompter):
"""
A dummy class for custom prompters
Expand Down

0 comments on commit 885f603

Please sign in to comment.