diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e61d138c41..4846e35840 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -36,7 +36,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import os import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, + cast) import datasets as hf_datasets import huggingface_hub as hf_hub @@ -57,6 +58,35 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: '.downloaded_finetuning')) SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] +PromptResponseDict = Dict[str, str] +ChatFormattedDict = Dict[str, List[Dict[str, str]]] +Example = Union[PromptResponseDict, ChatFormattedDict] +ExampleType = Literal['prompt_response', 'chat'] +TokenizedExample = Dict[str, List[int]] + + +def _get_example_type(example: Example) -> ExampleType: + """Determines the type of the input example. + + Args: + example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair. + + Returns: + ExampleType: The type of the input example, which can be either 'chat' for multi-way chat formatted conversation or 'prompt_response' for instruction-response pair. + + Raises: + KeyError: If the example type is unknown. + """ + if 'messages' in example: + return 'chat' + elif any([ + pr in example + for pr in _ALLOWED_PROMPT_KEYS.union(_ALLOWED_RESPONSE_KEYS) + ]): + return 'prompt_response' + else: + raise KeyError(f'Unknown conversation type {example=}') + def _is_empty_or_nonexistent(dirpath: str) -> bool: """Check if a directory is empty or non-existent. @@ -70,9 +100,70 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 -def _tokenize_formatted_example( - example: Dict[str, Any], - tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]: +def _slice_chat_formatted_example( + example: ChatFormattedDict, + tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: + """Slices the chat example into a formatted prompt and response. + + Args: + example (ChatFormattedDict): The chat example containing the messages. + tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template. + + Returns: + Tuple[str, str]: The prompt and response as separate strings. + + Raises: + ValueError: If the chat example has less than two messages or if the last message is not from the assistant. + KeyError: If a message does not have a role or content. + """ + messages = example['messages'] + + if len(messages) < 2: + raise ValueError( + f'chat example must have at least two messages. {messages=}') + last_message = messages[-1] + if last_message['role'] != 'assistant': + raise ValueError( + f'last message must be from assistant. {last_message=}') + for message in messages: + if 'role' not in message or 'content' not in message: + raise KeyError(f'message must have role and content. {message=}') + + full_conversation = tokenizer.apply_chat_template(messages, tokenize=False) + prompt = tokenizer.apply_chat_template(messages[:-1], + tokenize=False, + add_generation_prompt=True) + if prompt != full_conversation[:len(prompt)]: + raise ValueError( + f'prompt must be the first part of the full conversation. 
{prompt=}, {full_conversation=}' + ) + response = full_conversation[len(prompt):] + if len(response) == 0: + raise ValueError( + f'chat example must have at least one assistant message. {messages=}' + ) + return prompt, response + + +def _tokenize_chat_formatted_example( + example: ChatFormattedDict, + tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: + """Tokenizes a chat-formatted example using the provided tokenizer. + + Args: + example (ChatFormattedDict): The chat-formatted example to tokenize. + tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization. + + Returns: + TokenizedExample: The tokenized example. + """ + prompt, response = _slice_chat_formatted_example(example, tokenizer) + return tokenizer(text=prompt, text_target=response) + + +def _tokenize_prompt_response_formatted_example( + example: PromptResponseDict, + tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: """Tokenize a formatted example and validate expected keys.""" example_keys = set(example.keys()) prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS) @@ -108,6 +199,35 @@ def _tokenize_formatted_example( return tokenizer(text=prompt, text_target=response) +def _tokenize_formatted_example( + example: Example, + tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: + """Tokenizes a formatted example using the provided tokenizer. + + Args: + example (Example): The input example to be tokenized. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenization. + + Returns: + TokenizedExample: The tokenized example. + + Raises: + ValueError: If the example format is unknown. + """ + example_format = _get_example_type(example) + + if example_format == 'chat': + chat_example = cast(ChatFormattedDict, example) + return _tokenize_chat_formatted_example(chat_example, tokenizer) + elif example_format == 'prompt_response': + prompt_response_example: PromptResponseDict = cast( + PromptResponseDict, example) + return _tokenize_prompt_response_formatted_example( + prompt_response_example, tokenizer) + else: + raise ValueError(f'Unknown conversation type {example_format=}') + + class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. 
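
For reference (outside the patch itself), a minimal usage sketch of the new dispatch path: _tokenize_formatted_example accepts either example shape and routes to the matching helper. The tokenizer choice below is an assumption; any tokenizer that defines a chat template should behave the same way for the chat-formatted branch, mirroring the mosaicml/mpt-7b-8k-chat tokenizer used in the tests.

# Hedged sketch: both example shapes flow through the same entry point.
from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example
from llmfoundry.utils.builders import build_tokenizer

prompt_response_example = {'prompt': 'What is 2 + 2?', 'response': '4'}
chat_example = {
    'messages': [
        {'role': 'user', 'content': 'What is 2 + 2?'},
        {'role': 'assistant', 'content': '4'},
    ]
}

# Assumed checkpoint; its chat template drives apply_chat_template on the chat branch.
tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {})

for example in (prompt_response_example, chat_example):
    tokenized = _tokenize_formatted_example(example, tokenizer)
    assert 'input_ids' in tokenized and 'labels' in tokenized
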
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 44d0442a87..73e8427505 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -23,11 +23,8 @@ from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) from llmfoundry.data import build_dataloader -from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, - _ALLOWED_RESPONSE_KEYS, - DOWNLOADED_FT_DATASETS_DIRPATH, - SUPPORTED_EXTENSIONS, - _tokenize_formatted_example) +from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, get_tokens_per_batch_func) @@ -249,10 +246,12 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool, break +@pytest.mark.parametrize('use_chat_formatting', [True, False]) @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('allow_pad_trimming', [True, False]) @pytest.mark.parametrize('packing_ratio', [10.0, None, 'auto']) -def test_finetuning_dataloader(decoder_only_format: bool, +def test_finetuning_dataloader(use_chat_formatting: bool, + decoder_only_format: bool, allow_pad_trimming: bool, packing_ratio: Optional[Union[float, Literal['auto']]]): @@ -265,13 +264,21 @@ def test_finetuning_dataloader(decoder_only_format: bool, cfg = { 'name': 'finetuning', 'dataset': { - 'hf_name': 'HuggingFaceH4/databricks_dolly_15k', - 'split': 'train', - 'max_seq_len': max_seq_len, - 'decoder_only_format': decoder_only_format, - 'allow_pad_trimming': allow_pad_trimming, - 'packing_ratio': packing_ratio, - 'shuffle': True, + 'hf_name': + 'iamroot/chat_formatted_examples' if use_chat_formatting else + 'HuggingFaceH4/databricks_dolly_15k', + 'split': + 'train', + 'max_seq_len': + max_seq_len, + 'decoder_only_format': + decoder_only_format, + 'allow_pad_trimming': + allow_pad_trimming, + 'packing_ratio': + packing_ratio, + 'shuffle': + True, }, 'drop_last': False, 'num_workers': 0, @@ -417,39 +424,6 @@ def test_finetuning_dataloader_small_data(dataset_size: int, shutil.rmtree(tiny_dataset_folder_path) -def test_tokenize_example_malformed(): - no_keys = {} - no_prompt_key = {'response': 'response'} - no_response_key = {'prompt': 'prompt'} - extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'} - extra_keys_with_response = {'response': 'response', 'extra': 'extra'} - multiple_allowed_response_keys = { - 'prompt': 'prompt', - 'response': 'response', - 'completion': 'completion' - } - - malformed_examples = [ - no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, - extra_keys_with_response, multiple_allowed_response_keys - ] - - for example in malformed_examples: - with pytest.raises(KeyError): - _tokenize_formatted_example(example, MagicMock()) - - -def test_tokenize_example_well_formed(): - tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') - - for prompt_key in _ALLOWED_PROMPT_KEYS: - for response_key in _ALLOWED_RESPONSE_KEYS: - example = {prompt_key: 'prompt', response_key: 'response'} - tokenized_example = _tokenize_formatted_example(example, tokenizer) - assert 'input_ids' in tokenized_example - assert 'labels' in tokenized_example - - @pytest.mark.parametrize('split', ['train', 'custom', 'data']) def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): tokenizer_name = 'gpt2' diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py new file 
mode 100644 index 0000000000..258e491b32 --- /dev/null +++ b/tests/data/test_template_tokenization.py @@ -0,0 +1,163 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest +import transformers + +from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, + _ALLOWED_RESPONSE_KEYS, + _slice_chat_formatted_example, + _tokenize_formatted_example) +from llmfoundry.utils.builders import build_tokenizer + + +def test_tokenize_chat_example_malformed(): + no_content = {'messages': [{'role': 'user'}]} + too_few_messages = { + 'messages': [{ + 'role': 'assistant', + 'content': 'Hi, User!' + }] + } + ends_with_user_role = { + 'messages': [{ + 'role': 'user', + 'content': 'Hello GPT!' + }, { + 'role': 'assistant', + 'content': 'Hi, User!' + }, { + 'role': 'user', + 'content': 'user message not followed by an assistant label' + }] + } + no_assistant_message = { + 'messages': [{ + 'role': 'user', + 'content': 'Hello GPT!' + }, { + 'role': 'user', + 'content': 'user message not followed by an assistant label' + }] + } + malformed_chat_examples = [ + too_few_messages, no_content, ends_with_user_role, no_assistant_message + ] + my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) + for example in malformed_chat_examples: + with pytest.raises(Exception): + _tokenize_formatted_example( + example, my_tokenizer + ) # type: ignore (the typing here is supposed to be malformed) + + +def test_tokenize_chat_example_well_formed(): + chat_examples = [ + { + 'messages': [{ + 'role': 'user', + 'content': 'Hello, GPT' + }, { + 'role': 'assistant', + 'content': 'this is my response' + }] + }, # prompt/response but in chat format + { + 'messages': [ + { + 'role': 'user', + 'content': 'Hello, GPT' + }, + { + 'role': 'assistant', + 'content': 'this is my response' + }, + { + 'role': 'user', + 'content': 'Nice to hear that.' + }, + { + 'role': 'assistant', + 'content': 'multi-way chat works too!' + }, + ] + }, # multi-way chat + ] + + expected = [ + { + 'prompt': + '''<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers. +<|im_start|>user +Hello, GPT<|im_end|> +<|im_start|>assistant +''', + 'response': + 'this is my response<|im_end|>' + }, + { + 'prompt': + '''<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers. 
+<|im_start|>user +Hello, GPT<|im_end|> +<|im_start|>assistant +this is my response<|im_end|> +<|im_start|>user +Nice to hear that.<|im_end|> +<|im_start|>assistant +''', + 'response': + 'multi-way chat works too!<|im_end|>' + }, + ] + + chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) + assert len(expected) == len( + chat_examples) # if we add a new example, zip shouldn't fail silently + for chat_example, expected_stringification in zip(chat_examples, expected): + prompt, response = _slice_chat_formatted_example( + chat_example, chat_tokenizer) + tokenized_example = _tokenize_formatted_example(chat_example, + chat_tokenizer) + assert prompt == expected_stringification['prompt'] + assert response == expected_stringification['response'] + assert 'input_ids' in tokenized_example + assert 'labels' in tokenized_example + + +def test_tokenize_instruct_example_malformed(): + no_keys = {} + no_prompt_key = {'response': 'response'} + no_response_key = {'prompt': 'prompt'} + extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'} + extra_keys_with_response = {'response': 'response', 'extra': 'extra'} + multiple_allowed_response_keys = { + 'prompt': 'prompt', + 'response': 'response', + 'completion': 'completion' + } + + malformed_prompt_response_examples = [ + no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, + extra_keys_with_response, multiple_allowed_response_keys + ] + + for example in malformed_prompt_response_examples: + with pytest.raises(KeyError): + _tokenize_formatted_example(example, MagicMock()) + + +def test_tokenize_instruct_example_well_formed(): + tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') + + for prompt_key in _ALLOWED_PROMPT_KEYS: + for response_key in _ALLOWED_RESPONSE_KEYS: + + example = {prompt_key: 'prompt', response_key: 'response'} + tokenized_example = _tokenize_formatted_example(example, tokenizer) + assert 'input_ids' in tokenized_example + assert 'labels' in tokenized_example
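
The prompt/response split exercised above comes straight from the tokenizer's chat template: _slice_chat_formatted_example renders the full conversation once, re-renders the prompt (all but the last message) with add_generation_prompt=True, and takes the response as whatever remains. A hedged sketch of that slicing, using a toy template set on gpt2 purely for illustration (real runs rely on the template that ships with the finetuning tokenizer, e.g. mosaicml/mpt-7b-8k-chat as in the tests above):

# Hedged sketch: a hand-written chat template (an assumption, not the MPT one)
# makes the prompt/response slicing easy to see.
import transformers

from llmfoundry.data.finetuning.tasks import _slice_chat_formatted_example

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
tokenizer.chat_template = (
    "{% for m in messages %}<|{{ m['role'] }}|>{{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}")

example = {
    'messages': [
        {'role': 'user', 'content': 'Hello, GPT'},
        {'role': 'assistant', 'content': 'this is my response'},
    ]
}
prompt, response = _slice_chat_formatted_example(example, tokenizer)
print(repr(prompt))    # '<|user|>Hello, GPT\n<|assistant|>'
print(repr(response))  # 'this is my response\n'
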