Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ShareGPT chat format #1098

Merged
merged 9 commits into from
Apr 11, 2024
43 changes: 31 additions & 12 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
_ALLOWED_MESSAGES_KEYS = {'messages'}
_ALLOWED_ROLE_KEYS = {'role'}
_ALLOWED_CONTENT_KEYS = {'content'}
_ALLOWED_ROLES = {'user', 'assistant', 'system'}
_ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
Expand Down Expand Up @@ -217,7 +217,7 @@ def slice_out_last_turn(
if conversation_through_previous_turn != prompt_with_history[:len(
conversation_through_previous_turn)]:
raise ValueError(
f'The prompt_with_histry must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
)
if prompt_with_history != full_conversation[:len(prompt_with_history)]:
raise ValueError(
Expand Down Expand Up @@ -624,9 +624,8 @@ def print_registered_tasks(self) -> None:
log.info('\n'.join(tasks))

def get_preprocessing_fn_from_dict(
self,
mapping: Dict[str,
str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
self, mapping: Dict[str,
str]) -> Callable[[Dict[str, Any]], Example]:
"""Get a preprocessing function from a dictionary.

The dictionary maps column names in the dataset to "prompt" and "response".
Expand Down Expand Up @@ -662,7 +661,7 @@ def get_preprocessing_fn_from_str(
self,
preprocessor: Optional[str],
dataset_name: Optional[str] = None
) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
) -> Optional[Callable[[Dict[str, Any]], Example]]:
"""Get a preprocessing function from a string.

String can be either a registered function or an import path.
Expand Down Expand Up @@ -710,7 +709,7 @@ def get_preprocessing_fn_from_str(

def build_from_hf(
self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
tokenizer: PreTrainedTokenizerBase, target_prompts: str,
target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
Any]
Expand Down Expand Up @@ -793,7 +792,8 @@ def build_from_hf(

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return tokenize_formatted_example(preprocessing_fn(example),
tokenizer)
return tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
Expand Down Expand Up @@ -857,7 +857,7 @@ def build_from_streaming(self, *args: Any,


@dataset_constructor.register('tatsu-lab/alpaca')
def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
def alpaca_preprocessing_function(inp: Dict) -> PromptResponseDict:
"""Split out prompt/response from text."""
try:
prompt, response = inp['text'].split('### Response:')
Expand All @@ -869,7 +869,7 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:


@dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
def dolly_preprocessing_function(inp: Dict) -> PromptResponseDict:
"""Format the text string."""
PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
try:
Expand All @@ -885,7 +885,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:


@dataset_constructor.register('bigscience/P3')
def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
def p3_preprocessing_function(inp: Dict) -> PromptResponseDict:
"""Format the already-split example."""
return {
'prompt': inp['inputs'] + ':',
Expand All @@ -895,7 +895,7 @@ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:

# Muennighoff's P3 and flan datasets share a similar convention
@dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict:
"""Format the already-split example."""
try:
prompt: str = inp['inputs']
Expand All @@ -908,3 +908,22 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
except Exception as e:
raise UnableToProcessPromptResponseError(inp) from e
return {'prompt': prompt, 'response': response}


@dataset_constructor.register('teknium/OpenHermes-2.5')
def shareGPT_format_preprocessor(inp: Dict) -> ChatFormattedDict:
    """Translate a ShareGPT-style sample into the chat ``messages`` layout.

    Every entry of ``inp['conversations']`` carries its speaker under
    ``'from'`` and its text under ``'value'``. The speakers ``'human'`` and
    ``'gpt'`` are renamed to ``'user'`` and ``'assistant'``; any other speaker
    name is passed through unchanged.

    Args:
        inp (Dict): A sample holding a ``'conversations'`` list.

    Returns:
        A dict with a single ``'messages'`` key whose value is the list of
        ``{'role': ..., 'content': ...}`` turns, in the original order.

    Raises:
        UnableToProcessPromptResponseError: If reading the sample fails for
            any reason (for example a missing key); the original exception is
            chained as the cause.
    """
    sharegpt_to_chat_role = {'human': 'user', 'gpt': 'assistant'}
    try:
        # Unknown speakers fall back to their own name via the .get default.
        messages: List[Dict[str, str]] = [{
            'role': sharegpt_to_chat_role.get(turn['from'], turn['from']),
            'content': turn['value'],
        } for turn in inp['conversations']]
    except Exception as e:
        raise UnableToProcessPromptResponseError(inp) from e
    return {'messages': messages}
15 changes: 9 additions & 6 deletions scripts/data_prep/convert_finetuning_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def main(args: Namespace) -> None:
examples_removed = 0
for sample in tqdm(samples, desc=split_name):
formatted_sample = preprocessing_fn(sample)
assert isinstance(formatted_sample, dict)

# Use the _get_example_type utility to confirm that the formatted sample
# can be interpreted by the tokenization code
Expand Down Expand Up @@ -300,13 +301,15 @@ def main(args: Namespace) -> None:
out.write(sample_to_write)
else:
if example_type == 'prompt_response':
encoded_sample = {
key: formatted_sample[key].encode('utf-8')
for key in ['prompt', 'response']
}
encoded_sample = {}
for key in ['prompt', 'response']:
value = formatted_sample[key]
assert isinstance(value, str)
encoded_sample[key] = value.encode('utf-8')
out.write(encoded_sample)
else:
encoded_sample = formatted_sample
out.write(encoded_sample)
out.write(formatted_sample)

if tokenizer is not None and examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
Expand Down
84 changes: 84 additions & 0 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,3 +1093,87 @@ def test_build_unknown_dataloader():
tokenizer = MagicMock()
with pytest.raises(catalogue.RegistryError):
_ = build_dataloader(cfg, tokenizer, 2)


# Flag names that test_sharegpt_format is parametrized over: they are joined
# with ',' to form the parametrize argnames and also handed to
# generate_exclusive_test_params to produce the argvalues.
invalid_conversation_params_sharegpt = [
'add_invalid_last_chat_message', 'add_invalid_content_type',
'add_invalid_role', 'add_not_alternating_roles'
]


@pytest.mark.parametrize(
    ','.join(invalid_conversation_params_sharegpt),
    generate_exclusive_test_params(invalid_conversation_params_sharegpt))
def test_sharegpt_format(tmp_path: pathlib.Path,
                         add_invalid_last_chat_message: bool,
                         add_invalid_content_type: bool, add_invalid_role: bool,
                         add_not_alternating_roles: bool):
    """Build a finetuning dataloader over a tiny ShareGPT-format dataset.

    A small jsonl dataset is written in the ``conversations`` layout
    (``use_messages_format=False``) and loaded through the
    ``'teknium/OpenHermes-2.5'`` preprocessing function. Each flag asks the
    dataset generator for one kind of invalid sample; when a flag is set the
    matching error is expected while building the dataloader.
    """
    tokenizer_name = 'mosaicml/mpt-7b'
    max_seq_len = 2048
    dataset_size = 5
    device_batch_size = 5
    dataset_dir = tmp_path
    dataset_file = str(dataset_dir / 'train.jsonl')

    tokenizer = build_tokenizer(
        tokenizer_name=tokenizer_name,
        tokenizer_kwargs={'model_max_length': max_seq_len},
    )
    special_tokens = {
        'pad_token': '<pad>',
        'bos_token': '<bos>',
        'eos_token': '<eos>',
    }
    tokenizer.add_special_tokens(special_tokens)

    # Only global rank 0 writes the dataset file.
    if dist.get_global_rank() == 0:
        make_tiny_conversation_ft_dataset(
            path=dataset_file,
            size=dataset_size,
            add_invalid_last_chat_message=add_invalid_last_chat_message,
            add_invalid_message_key_quantity=False,
            add_invalid_content_type=add_invalid_content_type,
            add_invalid_role=add_invalid_role,
            add_not_alternating_roles=add_not_alternating_roles,
            use_messages_format=False,
        )

    dataset_cfg = {
        'hf_name': str(dataset_dir),
        'preprocessing_fn': 'teknium/OpenHermes-2.5',
        'split': 'train',
        'max_seq_len': max_seq_len,
        'decoder_only_format': True,
        'allow_pad_trimming': False,
        'packing_ratio': None,
        'shuffle': True,
    }
    cfg = om.create({
        'name': 'finetuning',
        'dataset': dataset_cfg,
        'drop_last': False,
        'num_workers': 0,
        'prefetch_factor': None,
        'pin_memory': False,
        'persistent_workers': False,
        'timeout': 0
    })

    # Ordered (flag, error type, message pattern) table; if several flags were
    # set, the last matching row decides the expected error.
    expected_failures = [
        (add_invalid_last_chat_message, InvalidLastChatMessageRoleError,
         'Invalid last message role:'),
        (add_invalid_content_type, InvalidContentTypeError,
         'Expected content to be'),
        (add_invalid_role, InvalidRoleError, 'Expected role to be one of'),
        (add_not_alternating_roles, ConsecutiveRepeatedChatRolesError,
         'Conversation roles must alternate'),
    ]
    error_context = contextlib.nullcontext()
    for triggered, error_cls, message_pattern in expected_failures:
        if triggered:
            error_context = pytest.raises(error_cls, match=message_pattern)

    with error_context:
        build_finetuning_dataloader(cfg, tokenizer,
                                    device_batch_size).dataloader
82 changes: 58 additions & 24 deletions tests/data/test_template_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
_ALLOWED_RESPONSE_KEYS,
_slice_chat_formatted_example,
dataset_constructor,
tokenize_formatted_example)
from llmfoundry.utils.builders import build_tokenizer

Expand Down Expand Up @@ -178,34 +179,67 @@ def test_tokenize_instruct_example_well_formed():
@pytest.mark.parametrize(
'tokenizer_name',
['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'])
def test_multi_turn_chat_slicing(tokenizer_name: str):
convo = [
{
'role': 'system',
'content': 'everyone thinks you are so cool'
},
{
'role': 'user',
'content': 'hiiii'
},
{
'role': 'assistant',
'content': 'yassss'
},
{
'role': 'user',
'content': 'HIIIIII!!!'
},
{
'role': 'assistant',
'content': 'YASSSSSS'
},
]
@pytest.mark.parametrize('messages_format', [True, False])
def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
if messages_format:
convo = [
{
'role': 'system',
'content': 'everyone thinks you are so cool'
},
{
'role': 'user',
'content': 'hiiii'
},
{
'role': 'assistant',
'content': 'yassss'
},
{
'role': 'user',
'content': 'HIIIIII!!!'
},
{
'role': 'assistant',
'content': 'YASSSSSS'
},
]
else:
convo = [
{
'from': 'system',
'value': 'everyone thinks you are so cool'
},
{
'from': 'human',
'value': 'hiiii'
},
{
'from': 'gpt',
'value': 'yassss'
},
{
'from': 'tool',
'value': 'HIIIIII!!!'
},
{
'from': 'gpt',
'value': 'YASSSSSS'
},
]
tmp = {'conversations': convo}
preprocessor = dataset_constructor.get_preprocessing_fn_from_str(
'teknium/OpenHermes-2.5')
assert preprocessor is not None
convo = preprocessor(tmp)['messages']
assert isinstance(convo, list)

example = {'messages': convo}

tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)

templated_prompt_response_turns = _slice_chat_formatted_example(
{'messages': convo}, tok)
example, tok)

reconstructed_chat = ''
for prompt, response in templated_prompt_response_turns:
Expand Down
21 changes: 20 additions & 1 deletion tests/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import shutil
from argparse import Namespace
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional

from omegaconf import DictConfig
from omegaconf import OmegaConf as om
Expand Down Expand Up @@ -99,6 +99,7 @@ def make_tiny_conversation_ft_dataset(
add_invalid_content_type: bool = False,
add_invalid_role: bool = False,
add_not_alternating_roles: bool = False,
use_messages_format: bool = True,
):
if Path(path).suffix != '.jsonl':
raise ValueError(f'Path {path} must be a jsonl file.')
Expand Down Expand Up @@ -198,6 +199,24 @@ def make_tiny_conversation_ft_dataset(
}]
})

def messages_to_conversation(sample: Dict):
    """Rewrite a ``messages``-format sample in the ShareGPT layout.

    Each ``{'role': ..., 'content': ...}`` turn becomes
    ``{'from': ..., 'value': ...}``. The roles ``'user'`` and ``'assistant'``
    are renamed to ``'human'`` and ``'gpt'``; every other role keeps its name.

    Args:
        sample (Dict): A sample that must contain a ``'messages'`` list.

    Returns:
        A dict with a single ``'conversations'`` key holding the converted
        turns in their original order.
    """
    assert 'messages' in sample

    chat_to_sharegpt_role = {'user': 'human', 'assistant': 'gpt'}
    # Roles without a ShareGPT alias fall back to themselves.
    conversations: List[Dict[str, str]] = [{
        'from': chat_to_sharegpt_role.get(turn['role'], turn['role']),
        'value': turn['content'],
    } for turn in sample['messages']]
    return {'conversations': conversations}

if not use_messages_format:
samples = [messages_to_conversation(sample) for sample in samples]

os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as _f:
for sample in samples:
Expand Down
Loading