diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 60d610da9b..d6ed3eda3e 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -65,8 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: def _get_example_type(example: Example) -> ExampleType: - """ - Determines the type of the input example. + """Determines the type of the input example. Args: example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair. @@ -76,7 +75,6 @@ def _get_example_type(example: Example) -> ExampleType: Raises: KeyError: If the example type is unknown. - """ if 'messages' in example: return 'chat' @@ -101,20 +99,20 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _slice_chat_formatted_example( example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: - """ - Applies the tokenizer's chat template to the example messages and slices the resulting templated string into a prompt and a completion. - + """Slices the chat example into a formatted prompt and response. + Args: example (ChatFormattedDict): The chat example containing the messages. tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template. - + Returns: Tuple[str, str]: The prompt and response as separate strings. - + Raises: ValueError: If the chat example has less than two messages or if the last message is not from the assistant. KeyError: If a message does not have a role or content. """ + def slice(s: str, sep: str): # it seems like we can reuse this logic, as we likely have this pattern in other places. slices = s.split(sep) @@ -144,8 +142,7 @@ def slice(s: str, sep: str): def _tokenize_chat_formatted_example( example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: - """ - Tokenizes a chat-formatted example using the provided tokenizer. + """Tokenizes a chat-formatted example using the provided tokenizer. Args: example (ChatFormattedDict): The chat-formatted example to tokenize. @@ -153,7 +150,6 @@ def _tokenize_chat_formatted_example( Returns: TokenizedExample: The tokenized example. - """ prompt, response = _slice_chat_formatted_example(example, tokenizer) return tokenizer(text=prompt, text_target=response) @@ -200,8 +196,7 @@ def _tokenize_prompt_response_formatted_example( def _tokenize_formatted_example( example: Example, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: - """ - Tokenizes a formatted example using the provided tokenizer. + """Tokenizes a formatted example using the provided tokenizer. Args: example (Example): The input example to be tokenized. diff --git a/tests/data/test_chat_tokenization.py b/tests/data/test_chat_tokenization.py new file mode 100644 index 0000000000..09325dce21 --- /dev/null +++ b/tests/data/test_chat_tokenization.py @@ -0,0 +1,139 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest +import transformers + +from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, + _ALLOWED_RESPONSE_KEYS, + _slice_chat_formatted_example, + _tokenize_formatted_example) +from llmfoundry.utils.builders import build_tokenizer + + +def test_tokenize_chat_example_malformed(): + no_content = {'messages': [{'role': 'user'}]} + ends_with_user_role = { + 'messages': [{ + 'role': 'user', + 'content': 'Hello GPT!' + }, { + 'role': 'assistant', + 'content': 'Hi, User!' + }, { + 'role': 'user', + 'content': 'user message not followed by an assistant label' + }] + } + no_assistant_message = { + 'messages': [{ + 'role': 'user', + 'content': 'Hello GPT!' + }, { + 'role': 'user', + 'content': 'user message not followed by an assistant label' + }] + } + malformed_chat_examples = [ + no_content, ends_with_user_role, no_assistant_message + ] + my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) + for example in malformed_chat_examples: + with pytest.raises(Exception): + _tokenize_formatted_example( + example, my_tokenizer + ) # type: ignore (the typing here is supposed to be malformed) + + +def test_tokenize_chat_example_well_formed(): + chat_examples = [ + { + 'messages': [{ + 'role': 'user', + 'content': 'Hello, GPT' + }, { + 'role': 'assistant', + 'content': 'this is my response' + }] + }, # prompt/response but in chat format + { + 'messages': [ + { + 'role': 'user', + 'content': 'Hello, GPT' + }, + { + 'role': 'assistant', + 'content': 'this is my response' + }, + { + 'role': 'user', + 'content': 'Nice to hear that.' + }, + { + 'role': 'assistant', + 'content': 'multi-way chat works too!' + }, + ] + }, # multi-way chat + ] + + expected = [ + { + 'prompt': 'TODO: fix', + 'response': 'TODO: fix' + }, + { + 'prompt': 'TODO: fix', + 'response': 'TODO: fix' + }, + ] + + chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) + assert len(expected) == len( + chat_examples) # if we add a new example, zip shouldn't fail silently + for chat_example, expected_stringification in zip(chat_examples, expected): + prompt, response = _slice_chat_formatted_example( + chat_example, chat_tokenizer) + tokenized_example = _tokenize_formatted_example(chat_example, + chat_tokenizer) + assert prompt == expected_stringification['prompt'] + assert response == expected_stringification['response'] + assert 'input_ids' in tokenized_example + assert 'labels' in tokenized_example + + +def test_tokenize_instruct_example_malformed(): + no_keys = {} + no_prompt_key = {'response': 'response'} + no_response_key = {'prompt': 'prompt'} + extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'} + extra_keys_with_response = {'response': 'response', 'extra': 'extra'} + multiple_allowed_response_keys = { + 'prompt': 'prompt', + 'response': 'response', + 'completion': 'completion' + } + + malformed_prompt_response_examples = [ + no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, + extra_keys_with_response, multiple_allowed_response_keys + ] + + for example in malformed_prompt_response_examples: + with pytest.raises(KeyError): + _tokenize_formatted_example(example, MagicMock()) + + +def test_tokenize_instruct_example_well_formed(): + tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') + + for prompt_key in _ALLOWED_PROMPT_KEYS: + for response_key in _ALLOWED_RESPONSE_KEYS: + + example = {prompt_key: 'prompt', response_key: 'response'} + tokenized_example = _tokenize_formatted_example(example, tokenizer) + assert 'input_ids' in tokenized_example + assert 'labels' in tokenized_example diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 2edbe149a9..e4f26f70f7 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -23,12 +23,8 @@ from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) from llmfoundry.data import build_dataloader -from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, - _ALLOWED_RESPONSE_KEYS, - DOWNLOADED_FT_DATASETS_DIRPATH, - SUPPORTED_EXTENSIONS, - _slice_chat_formatted_example, - _tokenize_formatted_example) +from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, get_tokens_per_batch_func) @@ -428,124 +424,6 @@ def test_finetuning_dataloader_small_data(dataset_size: int, shutil.rmtree(tiny_dataset_folder_path) -def test_tokenize_instruct_example_malformed(): - no_keys = {} - no_prompt_key = {'response': 'response'} - no_response_key = {'prompt': 'prompt'} - extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'} - extra_keys_with_response = {'response': 'response', 'extra': 'extra'} - multiple_allowed_response_keys = { - 'prompt': 'prompt', - 'response': 'response', - 'completion': 'completion' - } - - malformed_prompt_response_examples = [ - no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, - extra_keys_with_response, multiple_allowed_response_keys - ] - - for example in malformed_prompt_response_examples: - with pytest.raises(KeyError): - _tokenize_formatted_example(example, MagicMock()) - - -def test_tokenize_chat_example_malformed(): - no_content = {'messages': [{'role': 'user'}]} - ends_with_user_role = { - 'messages': [{ - 'role': 'user', - 'content': 'Hello GPT!' - }, { - 'role': 'assistant', - 'content': 'Hi, User!' - }, { - 'role': 'user', - 'content': 'user message not followed by an assistant label' - }] - } - no_assistant_message = { - 'messages': [{ - 'role': 'user', - 'content': 'Hello GPT!' - }, { - 'role': 'user', - 'content': 'user message not followed by an assistant label' - }] - } - malformed_chat_examples = [ - no_content, ends_with_user_role, no_assistant_message - ] - my_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {}) - for example in malformed_chat_examples: - with pytest.raises(Exception): - _tokenize_formatted_example( - example, my_tokenizer - ) # type: ignore (the typing here is supposed to be malformed) - - -def test_tokenize_instruct_example_well_formed(): - tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') - - for prompt_key in _ALLOWED_PROMPT_KEYS: - for response_key in _ALLOWED_RESPONSE_KEYS: - - example = {prompt_key: 'prompt', response_key: 'response'} - tokenized_example = _tokenize_formatted_example(example, tokenizer) - assert 'input_ids' in tokenized_example - assert 'labels' in tokenized_example - - -def test_tokenize_chat_example_well_formed(): - chat_examples = [ - { - 'messages': [{ - 'role': 'user', - 'content': 'Hello, GPT' - }, { - 'role': 'assistant', - 'content': 'this is my response' - }] - }, # prompt/response but in chat format - { - 'messages': [ - { - 'role': 'user', - 'content': 'Hello, GPT' - }, - { - 'role': 'assistant', - 'content': 'this is my response' - }, - { - 'role': 'user', - 'content': 'Nice to hear that.' - }, - { - 'role': 'assistant', - 'content': 'multi-way chat works too!' - }, - ] - }, # multi-way chat - ] - - chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {}) - for chat_example in chat_examples: - last_message = chat_example['messages'][-1]['content'] - earlier_messages = [ - msg['content'] for msg in chat_example['messages'][:-1] - ] - prompt, response = _slice_chat_formatted_example( - chat_example, chat_tokenizer) - tokenized_example = _tokenize_formatted_example(chat_example, - chat_tokenizer) - assert last_message in response - for earlier_message in earlier_messages: - assert earlier_message in prompt - assert 'input_ids' in tokenized_example - assert 'labels' in tokenized_example - - @pytest.mark.parametrize('split', ['train', 'custom', 'data']) def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): tokenizer_name = 'gpt2'