Adds support for chat formatted finetuning input data. (#884)
* fix conflicting formatting linting guidelines

* used older union operator for legacy support

* did the same thing in another place

* isort ignore specific lines

* fixes

* isort do not skip line

* address comments

* renamed some more things

* split tests and add some verification for tokenization split

* fix formatting

* added docstrings

* added end-to-end-test with HF dataset

* fix code style

* renamed file and fixed tests

* use chat template diff

* addressed comment

* Update llmfoundry/data/finetuning/tasks.py

Co-authored-by: Daniel King <[email protected]>

* Update llmfoundry/data/finetuning/tasks.py

Co-authored-by: Daniel King <[email protected]>

* fixed type of TokenizedExample

* use cast

* use _ALLOWED_{PROMPT, RESPONSE}_KEYS

* updated tests

* fix

* fix?

* Update llmfoundry/data/finetuning/tasks.py

Co-authored-by: Daniel King <[email protected]>

* Update llmfoundry/data/finetuning/tasks.py

Co-authored-by: Daniel King <[email protected]>

---------

Co-authored-by: Daniel King <[email protected]>
milocress and dakinggg authored Jan 26, 2024
1 parent 2634987 · commit ac78354
Showing 3 changed files with 307 additions and 50 deletions.
128 changes: 124 additions & 4 deletions llmfoundry/data/finetuning/tasks.py
@@ -36,7 +36,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import os
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
                    cast)

import datasets as hf_datasets
import huggingface_hub as hf_hub
@@ -57,6 +58,35 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
                 '.downloaded_finetuning'))
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

PromptResponseDict = Dict[str, str]
ChatFormattedDict = Dict[str, List[Dict[str, str]]]
Example = Union[PromptResponseDict, ChatFormattedDict]
ExampleType = Literal['prompt_response', 'chat']
TokenizedExample = Dict[str, List[int]]
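
Concretely, the two shapes these aliases describe look like the following (an illustrative sketch, not part of this commit; the values are invented):

# --- Illustrative sketch (not part of this commit) ---
# A prompt/response example pairs one allowed prompt key with one allowed
# response key ('prompt' and 'response' shown here).
prompt_response_example: PromptResponseDict = {
    'prompt': 'What is the capital of France?',
    'response': 'Paris.',
}

# A chat-formatted example wraps the conversation in a 'messages' list of
# role/content turns, ending with an assistant turn.
chat_example: ChatFormattedDict = {
    'messages': [
        {'role': 'user', 'content': 'What is the capital of France?'},
        {'role': 'assistant', 'content': 'Paris.'},
    ],
}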


def _get_example_type(example: Example) -> ExampleType:
    """Determines the type of the input example.

    Args:
        example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair.

    Returns:
        ExampleType: The type of the input example, which can be either 'chat' for multi-way chat formatted conversation or 'prompt_response' for instruction-response pair.

    Raises:
        KeyError: If the example type is unknown.
    """
    if 'messages' in example:
        return 'chat'
    elif any([
            pr in example
            for pr in _ALLOWED_PROMPT_KEYS.union(_ALLOWED_RESPONSE_KEYS)
    ]):
        return 'prompt_response'
    else:
        raise KeyError(f'Unknown conversation type {example=}')
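
A quick check of the dispatch, using hypothetical values:

# --- Illustrative sketch (not part of this commit) ---
assert _get_example_type({'messages': [{'role': 'user', 'content': 'hi'}]}) == 'chat'
assert _get_example_type({'prompt': 'p', 'response': 'r'}) == 'prompt_response'
# Any other shape raises KeyError, e.g. _get_example_type({'foo': 'bar'}).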


def _is_empty_or_nonexistent(dirpath: str) -> bool:
    """Check if a directory is empty or non-existent.
@@ -70,9 +100,70 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
    return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


def _tokenize_formatted_example(
        example: Dict[str, Any],
        tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
def _slice_chat_formatted_example(
        example: ChatFormattedDict,
        tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
    """Slices the chat example into a formatted prompt and response.

    Args:
        example (ChatFormattedDict): The chat example containing the messages.
        tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template.

    Returns:
        Tuple[str, str]: The prompt and response as separate strings.

    Raises:
        ValueError: If the chat example has less than two messages or if the last message is not from the assistant.
        KeyError: If a message does not have a role or content.
    """
    messages = example['messages']

    if len(messages) < 2:
        raise ValueError(
            f'chat example must have at least two messages. {messages=}')
    last_message = messages[-1]
    if last_message['role'] != 'assistant':
        raise ValueError(
            f'last message must be from assistant. {last_message=}')
    for message in messages:
        if 'role' not in message or 'content' not in message:
            raise KeyError(f'message must have role and content. {message=}')

    full_conversation = tokenizer.apply_chat_template(messages, tokenize=False)
    prompt = tokenizer.apply_chat_template(messages[:-1],
                                           tokenize=False,
                                           add_generation_prompt=True)
    if prompt != full_conversation[:len(prompt)]:
        raise ValueError(
            f'prompt must be the first part of the full conversation. {prompt=}, {full_conversation=}'
        )
    response = full_conversation[len(prompt):]
    if len(response) == 0:
        raise ValueError(
            f'chat example must have at least one assistant message. {messages=}'
        )
    return prompt, response
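
To see what the slicing yields, here is a hedged sketch with a tokenizer that ships a chat template (the model choice is an assumption; the exact strings depend on its template):

# --- Illustrative sketch (not part of this commit) ---
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
messages = [
    {'role': 'user', 'content': 'Hi!'},
    {'role': 'assistant', 'content': 'Hello! How can I help?'},
]
prompt, response = _slice_chat_formatted_example({'messages': messages}, tok)
# `prompt` is the templated conversation up to and including the generation
# prompt for the assistant turn; `response` is the remaining templated text,
# so prompt + response == tok.apply_chat_template(messages, tokenize=False).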


def _tokenize_chat_formatted_example(
        example: ChatFormattedDict,
        tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
    """Tokenizes a chat-formatted example using the provided tokenizer.

    Args:
        example (ChatFormattedDict): The chat-formatted example to tokenize.
        tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.

    Returns:
        TokenizedExample: The tokenized example.
    """
    prompt, response = _slice_chat_formatted_example(example, tokenizer)
    return tokenizer(text=prompt, text_target=response)
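
The `text`/`text_target` call is the standard Hugging Face idiom for producing both halves at once: `input_ids` come from `text` and `labels` from `text_target`. A minimal check, reusing `tok` from the sketch above:

# --- Illustrative sketch (not part of this commit) ---
out = tok('Question: 2 + 2?', text_target=' 4')
assert 'input_ids' in out and 'labels' in out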


def _tokenize_prompt_response_formatted_example(
        example: PromptResponseDict,
        tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
    """Tokenize a formatted example and validate expected keys."""
    example_keys = set(example.keys())
    prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
@@ -108,6 +199,35 @@ def _tokenize_formatted_example(
    return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
        example: Example,
        tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
    """Tokenizes a formatted example using the provided tokenizer.

    Args:
        example (Example): The input example to be tokenized.
        tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenization.

    Returns:
        TokenizedExample: The tokenized example.

    Raises:
        ValueError: If the example format is unknown.
    """
    example_format = _get_example_type(example)

    if example_format == 'chat':
        chat_example = cast(ChatFormattedDict, example)
        return _tokenize_chat_formatted_example(chat_example, tokenizer)
    elif example_format == 'prompt_response':
        prompt_response_example: PromptResponseDict = cast(
            PromptResponseDict, example)
        return _tokenize_prompt_response_formatted_example(
            prompt_response_example, tokenizer)
    else:
        raise ValueError(f'Unknown conversation type {example_format=}')
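
End to end, both formats now flow through this single entry point (sketch, reusing `tok` and the example dicts from the earlier sketches):

# --- Illustrative sketch (not part of this commit) ---
for ex in (prompt_response_example, chat_example):
    tokenized = _tokenize_formatted_example(ex, tok)
    assert 'input_ids' in tokenized and 'labels' in tokenized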


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
66 changes: 20 additions & 46 deletions tests/data/test_dataloader.py
@@ -23,11 +23,8 @@
from llmfoundry import (build_finetuning_dataloader,
                        build_text_denoising_dataloader)
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
                                              _ALLOWED_RESPONSE_KEYS,
                                              DOWNLOADED_FT_DATASETS_DIRPATH,
                                              SUPPORTED_EXTENSIONS,
                                              _tokenize_formatted_example)
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
                                              SUPPORTED_EXTENSIONS)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
                                       build_text_dataloader,
                                       get_tokens_per_batch_func)
@@ -249,10 +246,12 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
        break


@pytest.mark.parametrize('use_chat_formatting', [True, False])
@pytest.mark.parametrize('decoder_only_format', [True, False])
@pytest.mark.parametrize('allow_pad_trimming', [True, False])
@pytest.mark.parametrize('packing_ratio', [10.0, None, 'auto'])
def test_finetuning_dataloader(decoder_only_format: bool,
def test_finetuning_dataloader(use_chat_formatting: bool,
                               decoder_only_format: bool,
                               allow_pad_trimming: bool,
                               packing_ratio: Optional[Union[float,
                                                             Literal['auto']]]):
@@ -265,13 +264,21 @@ def test_finetuning_dataloader(decoder_only_format: bool,
    cfg = {
        'name': 'finetuning',
        'dataset': {
            'hf_name': 'HuggingFaceH4/databricks_dolly_15k',
            'split': 'train',
            'max_seq_len': max_seq_len,
            'decoder_only_format': decoder_only_format,
            'allow_pad_trimming': allow_pad_trimming,
            'packing_ratio': packing_ratio,
            'shuffle': True,
            'hf_name':
                'iamroot/chat_formatted_examples' if use_chat_formatting else
                'HuggingFaceH4/databricks_dolly_15k',
            'split':
                'train',
            'max_seq_len':
                max_seq_len,
            'decoder_only_format':
                decoder_only_format,
            'allow_pad_trimming':
                allow_pad_trimming,
            'packing_ratio':
                packing_ratio,
            'shuffle':
                True,
        },
        'drop_last': False,
        'num_workers': 0,
@@ -417,39 +424,6 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
    shutil.rmtree(tiny_dataset_folder_path)
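
For local `.jsonl` sources the same schema applies: each row can carry either shape. A sketch of writing a chat-formatted row (invented content; `iamroot/chat_formatted_examples` above is the hosted equivalent the test pulls):

# --- Illustrative sketch (not part of this commit) ---
import json

with open('train.jsonl', 'w') as f:
    f.write(json.dumps({
        'messages': [
            {'role': 'user', 'content': 'Summarize Hamlet in one line.'},
            {'role': 'assistant', 'content': 'A prince avenges his father.'},
        ],
    }) + '\n')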


def test_tokenize_example_malformed():
    no_keys = {}
    no_prompt_key = {'response': 'response'}
    no_response_key = {'prompt': 'prompt'}
    extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'}
    extra_keys_with_response = {'response': 'response', 'extra': 'extra'}
    multiple_allowed_response_keys = {
        'prompt': 'prompt',
        'response': 'response',
        'completion': 'completion'
    }

    malformed_examples = [
        no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
        extra_keys_with_response, multiple_allowed_response_keys
    ]

    for example in malformed_examples:
        with pytest.raises(KeyError):
            _tokenize_formatted_example(example, MagicMock())


def test_tokenize_example_well_formed():
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

    for prompt_key in _ALLOWED_PROMPT_KEYS:
        for response_key in _ALLOWED_RESPONSE_KEYS:
            example = {prompt_key: 'prompt', response_key: 'response'}
            tokenized_example = _tokenize_formatted_example(example, tokenizer)
            assert 'input_ids' in tokenized_example
            assert 'labels' in tokenized_example


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
    tokenizer_name = 'gpt2'