# Keys under which a formatted example may carry its target text and its
# input text. Exactly one key from each set must be present in an example.
_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}


def _tokenize_formatted_example(
        example: Dict[str, Any],
        tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
    """Validate a formatted example's keys and tokenize it.

    The example must contain exactly one key from ``_ALLOWED_PROMPT_KEYS``
    and exactly one key from ``_ALLOWED_RESPONSE_KEYS``; any other keys are
    ignored. Both selected values must be strings.

    Args:
        example: Mapping holding the prompt text and the response text.
        tokenizer: Callable invoked as
            ``tokenizer(text=prompt, text_target=response)``.

    Returns:
        Whatever ``tokenizer`` returns for the prompt/response pair.

    Raises:
        KeyError: If zero or more than one allowed prompt key is present, or
            zero or more than one allowed response key is present.
        TypeError: If the prompt value or the response value is not a ``str``.
    """
    example_keys = set(example.keys())
    prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
    response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

    if len(prompt_keys) != 1:
        raise KeyError(
            f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys '
            f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
        )

    if len(response_keys) != 1:
        raise KeyError(
            f'Unable to tokenize example because {len(response_keys)} of the allowed response keys '
            f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
        )

    # Each intersection holds exactly one element here, so pop() is
    # deterministic despite sets being unordered.
    prompt_key = prompt_keys.pop()
    response_key = response_keys.pop()
    prompt = example[prompt_key]
    response = example[response_key]

    # The key name is wrapped in double quotes so the message keeps the
    # established form: ... because "prompt" was not a string ...
    if not isinstance(prompt, str):
        raise TypeError(
            f'Unable to tokenize example because "{prompt_key}" was not a string. {example=}'
        )

    if not isinstance(response, str):
        raise TypeError(
            f'Unable to tokenize example because "{response_key}" was not a string. {example=}'
        )

    return tokenizer(text=prompt, text_target=response)
multiple_allowed_response_keys = { + 'prompt': 'prompt', + 'response': 'response', + 'completion': 'completion' + } + + malformed_examples = [ + no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, + extra_keys_with_response, multiple_allowed_response_keys + ] + + for example in malformed_examples: + with pytest.raises(KeyError): + _tokenize_formatted_example(example, MagicMock()) + + +def test_tokenize_example_well_formed(): + tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') + + for prompt_key in _ALLOWED_PROMPT_KEYS: + for response_key in _ALLOWED_RESPONSE_KEYS: + example = {prompt_key: 'prompt', response_key: 'response'} + tokenized_example = _tokenize_formatted_example(example, tokenizer) + assert 'input_ids' in tokenized_example + assert 'labels' in tokenized_example + + @pytest.mark.parametrize('split', ['train', 'custom', 'data']) def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): tokenizer_name = 'gpt2'