Skip to content

Commit

Permalink
Feat: Add sharegpt multirole (#1137)
Browse files Browse the repository at this point in the history
* feat(prompt): support multiple roles for sharegpt

* fix: add handling of empty role back

* feat: rebased and allowed more dynamic roles via config

* fix: variable

* chore: update message

* feat: add vicuna format

* fix: JSON serializable error

* fix: typing

* fix: don't remap for unknown keys

* fix: add roles to pydantic

* feat: add test

* chore: remove leftover print

* chore: remove leftover comment

* chore: remove print

* fix: update test to use chatml
  • Loading branch information
NanoCode012 authored Mar 19, 2024
1 parent 16a6049 commit 0e03717
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 26 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -651,9 +651,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].

# Custom user instruction prompt
- path: repo
Expand Down
12 changes: 11 additions & 1 deletion src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

import logging
from typing import Any, Dict, Optional

from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
Expand All @@ -11,6 +12,8 @@
merge_consecutive_messages,
)

LOG = logging.getLogger("axolotl")


def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
Expand Down Expand Up @@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
Expand Down Expand Up @@ -142,7 +147,12 @@ def get_conversation_thread(self, prompt):
"system": "system",
}
turns = [
{"from": role_map[t[role_key]], "value": t[value_key]}
{
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
}
for t in conversations
]
return turns
Expand Down
49 changes: 30 additions & 19 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter

LOG = logging.getLogger("axolotl")

Expand All @@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):

def __init__(
self,
prompter,
prompter: Prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
Expand Down Expand Up @@ -340,6 +340,23 @@ def tokenize_prompt(self, prompt):
self.prompter._conversation.copy() # pylint: disable=protected-access
)

input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}

if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)

# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)

if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)

# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
Expand All @@ -360,19 +377,18 @@ def tokenize_prompt(self, prompt):
LOG.warning(f"expected tuple, got {part}")
continue

tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part

# Uses "in" because role contains extra characters
if user_role_label in role:
input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""

if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue

if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
Expand All @@ -392,7 +408,7 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role:
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
Expand Down Expand Up @@ -423,7 +439,7 @@ def tokenize_prompt(self, prompt):
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
Expand All @@ -434,11 +450,6 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
Expand Down
35 changes: 30 additions & 5 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@ def __repr__(self) -> str:
"Role did not alternate between turns (gpt and human). Please check your data."
)

CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}",
}


class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
Expand All @@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
role_key_tool: Optional[str] = None
# Optional, role input/output mapping
roles: Optional[dict] = None

def __init__(
self,
Expand All @@ -277,6 +285,7 @@ def __init__(
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
if conversation:
if isinstance(conversation, Conversation):
Expand All @@ -291,6 +300,8 @@ def __init__(
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
if roles:
self.roles = roles

def _build_result(self, source):
if len(source) < 2:
Expand Down Expand Up @@ -322,11 +333,23 @@ def _build_result(self, source):

conv.messages = []
for _, sentence in enumerate(source):
role = roles[sentence["from"]]
if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
from_role = sentence["from"]
if from_role in roles:
role = roles[from_role]
else:
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
raise NotImplementedError(
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
"Please help us by creating an Issue to add support for this conversation type."
)

role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)

if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")

conv.append_message(role, sentence["value"])

return conv.get_turns()
Expand Down Expand Up @@ -354,11 +377,13 @@ def __init__(
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
roles=roles,
)


Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None

roles: Optional[Dict[str, List[str]]] = None


class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
Expand Down
68 changes: 68 additions & 0 deletions tests/prompt_strategies/test_sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
)


@pytest.fixture(name="multi_role_dataset")
def fixture_multi_role_dataset():
return Dataset.from_list(
[
{
"conversations": [
{
"from": "system",
"value": "use get_weather(city) to get the weather for a city",
},
{
"from": "human",
"value": "hello, what's the weather in New York?",
},
{
"from": "gpt",
"value": "let me get that for you",
},
{
"from": "tool",
"value": "get_weather(New York)",
},
{
"from": "gpt",
"value": "the weather in New York is 70 degrees and sunny",
},
]
}
]
)


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
Expand Down Expand Up @@ -196,3 +228,39 @@ def test_chatml_glaive(self, glaive_dataset, tokenizer):
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
]
# fmt: on

def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)

dataset_wrapper = TokenizedPromptDataset(
strategy, multi_role_dataset, process_count=1
)

input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
1, # bos
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
]
# fmt: on

labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
]
# fmt: on

0 comments on commit 0e03717

Please sign in to comment.