Feat: Add sharegpt multirole #1137

Merged · 16 commits · Mar 19, 2024
README.md: 5 additions & 1 deletion
@@ -635,9 +635,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_inputs
output: # Optional[List[str]]. Tokens from these roles are always trained on.

# Custom user instruction prompt
- path: repo
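For reference, a dataset entry using the new `roles` mapping parses into the `ds_cfg` dict consumed by the sharegpt prompt strategy. A minimal sketch, mirroring the YAML keys documented above (the dataset path and the `tool` role name are hypothetical, for illustration only):

```python
# Hypothetical parsed dataset config mirroring the YAML above. "roles" maps
# extra dataset keys onto input/output behavior: input roles are masked
# according to train_on_inputs, output roles always contribute to the loss.
ds_cfg = {
    "path": "my-org/my-multirole-dataset",  # hypothetical dataset path
    "type": "sharegpt",
    "conversation": "chatml",
    "field_human": "human",
    "field_model": "gpt",
    "roles": {
        "input": ["tool"],
        "output": ["gpt"],
    },
}
```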
src/axolotl/prompt_strategies/sharegpt.py: 11 additions & 1 deletion
@@ -1,5 +1,6 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

import logging
from typing import Any, Dict, Optional

from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -11,6 +12,8 @@
merge_consecutive_messages,
)

LOG = logging.getLogger("axolotl")


def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
@@ -142,7 +147,12 @@ def get_conversation_thread(self, prompt):
"system": "system",
}
turns = [
{"from": role_map[t[role_key]], "value": t[value_key]}
{
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
}
for t in conversations
]
return turns
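The `get_conversation_thread` change above amounts to a pass-through: known ShareGPT role names are still normalized via `role_map`, while unknown names such as `tool` now survive instead of raising a `KeyError`. A minimal sketch, assuming a `role_map` along the lines of the one in this module:

```python
# Known ShareGPT roles are normalized; anything else (e.g. "tool") passes
# through unchanged so the prompter can handle it downstream.
role_map = {"human": "human", "gpt": "gpt", "system": "system"}

def map_role(turn_role: str) -> str:
    return role_map[turn_role] if turn_role in role_map else turn_role

assert map_role("gpt") == "gpt"
assert map_role("tool") == "tool"  # a plain role_map lookup would raise KeyError
```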
src/axolotl/prompt_tokenizers.py: 30 additions & 19 deletions
@@ -11,7 +11,7 @@
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter

LOG = logging.getLogger("axolotl")

@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):

def __init__(
self,
prompter,
prompter: Prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
@@ -340,6 +340,23 @@ def tokenize_prompt(self, prompt):
self.prompter._conversation.copy() # pylint: disable=protected-access
)

input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}

if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)

# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)

if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)

# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
@@ -360,19 +377,18 @@
LOG.warning(f"expected tuple, got {part}")
continue

tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part

# Uses "in" because role contains extra characters
if user_role_label in role:
input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""

if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue

if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -392,7 +408,7 @@
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role:
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -423,7 +439,7 @@
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
@@ -434,11 +450,6 @@
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
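The rewritten masking loop replaces the old user/assistant/tool if-chain with two role sets built from `conversation.roles` plus the optional config mapping. A minimal sketch of the turn classification it performs, assuming chatml-style rendered roles such as `<|im_start|>tool`:

```python
IGNORE_TOKEN_ID = -100  # label value that excludes a token from the loss

# Built from conversation.roles plus the optional cfg "roles" mapping.
input_roles = {"human", "tool"}
output_roles = {"gpt"}

def classify_turn(role: str) -> str:
    # substring ("in") checks, because the rendered role string carries
    # extra template characters around the bare role name
    if any(r.lower() in role.lower() for r in input_roles):
        return "input"    # masked with IGNORE_TOKEN_ID unless train_on_inputs
    if any(r.lower() in role.lower() for r in output_roles):
        return "output"   # labels kept: the model is trained on these tokens
    if role.strip() == "":
        return "empty"    # first chunk (bos token + user query), masked like input
    return "unhandled"    # logged via LOG.warning and skipped

assert classify_turn("<|im_start|>tool") == "input"
assert classify_turn("<|im_start|>gpt") == "output"
```

Note that input turns are checked first, so a role listed in both sets is treated as input.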
src/axolotl/prompters.py: 30 additions & 5 deletions
@@ -259,6 +259,12 @@ def __repr__(self) -> str:
"Role did not alternate between turns (gpt and human). Please check your data."
)

CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}",
}


class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
role_key_tool: Optional[str] = None
# Optional, mapping of dataset role names to input/output roles
roles: Optional[dict] = None

def __init__(
self,
@@ -277,6 +285,7 @@ def __init__(
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -291,6 +300,8 @@
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
if roles:
self.roles = roles

def _build_result(self, source):
if len(source) < 2:
@@ -322,11 +333,23 @@ def _build_result(self, source):

conv.messages = []
for _, sentence in enumerate(source):
role = roles[sentence["from"]]
if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
from_role = sentence["from"]
if from_role in roles:
role = roles[from_role]
else:
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
raise NotImplementedError(
f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet. "
"Please help us by creating an Issue to add support for this conversation type."
)

role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)

if len(conv.messages) > 0 and role == conv.messages[-1][0]:
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")

conv.append_message(role, sentence["value"])

return conv.get_turns()
@@ -354,11 +377,13 @@ def __init__(
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
roles=roles,
)


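The `_build_result` change introduces a template-aware fallback for turns whose `from` key is not in the prompter's known roles. A condensed, standalone sketch of that logic, using the same `CONVERSATION_ROLE_FORMAT` table as above:

```python
CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
    "vicuna_v1.1": "{ROLE}",
}

def resolve_role(conversation_name: str, roles: dict, from_role: str) -> str:
    # Known roles map directly; unknown ones are rendered with the
    # conversation template's role pattern, when that template is supported.
    if from_role in roles:
        return roles[from_role]
    if conversation_name not in CONVERSATION_ROLE_FORMAT:
        raise NotImplementedError(
            f"Role ({from_role}) not in default roles, and {conversation_name} "
            "does not support role remapping yet."
        )
    return CONVERSATION_ROLE_FORMAT[conversation_name].format(ROLE=from_role)

# a "tool" turn under chatml renders as "<|im_start|>tool"
assert resolve_role("chatml", {"human": "user", "gpt": "assistant"}, "tool") == "<|im_start|>tool"
```

This is why only chatml, zephyr, and vicuna_v1.1 conversations accept custom dataset roles; any other template raises `NotImplementedError`.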
src/axolotl/utils/config/models/input/v0_4_1/__init__.py: 2 additions & 0 deletions
@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None

roles: Optional[Dict[str, List[str]]] = None


class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
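For completeness, a minimal pydantic sketch of how the new field validates (a pared-down stand-in for the real `SFTDataset`, which has many more fields; assumes pydantic v2):

```python
from typing import Dict, List, Optional

from pydantic import BaseModel

class SFTDataset(BaseModel):  # pared-down stand-in for the real model
    field_human: Optional[str] = None
    field_model: Optional[str] = None
    roles: Optional[Dict[str, List[str]]] = None

ds = SFTDataset.model_validate({"roles": {"input": ["tool"], "output": ["gpt"]}})
assert ds.roles == {"input": ["tool"], "output": ["gpt"]}
```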
tests/prompt_strategies/test_sharegpt.py: 68 additions & 0 deletions
@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
)


@pytest.fixture(name="multi_role_dataset")
def fixture_multi_role_dataset():
return Dataset.from_list(
[
{
"conversations": [
{
"from": "system",
"value": "use get_weather(city) to get the weather for a city",
},
{
"from": "human",
"value": "hello, what's the weather in New York?",
},
{
"from": "gpt",
"value": "let me get that for you",
},
{
"from": "tool",
"value": "get_weather(New York)",
},
{
"from": "gpt",
"value": "the weather in New York is 70 degrees and sunny",
},
]
}
]
)


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -196,3 +228,39 @@ def test_chatml_glaive(self, glaive_dataset, tokenizer):
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
]
# fmt: on

def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)

dataset_wrapper = TokenizedPromptDataset(
strategy, multi_role_dataset, process_count=1
)

input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
1, # bos
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
]
# fmt: on

labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
]
# fmt: on