Adds support for chat formatted finetuning input data. #884

Merged
merged 28 commits on Jan 26, 2024
Changes from 26 commits
Commits
28 commits
5f5a144
fix conflicting formatting linting guidelines
milocress Jan 18, 2024
3ef2b20
used older union operator for legacy support
milocress Jan 18, 2024
21aa9e3
did the same thing in another place
milocress Jan 18, 2024
2148405
isort ignore specific lines
milocress Jan 18, 2024
b660dbc
fixes
milocress Jan 18, 2024
fda8ab3
isort do not skip line
milocress Jan 18, 2024
ef7c78e
address comments
milocress Jan 19, 2024
e0bd660
renamed some more things
milocress Jan 19, 2024
6e73e04
split tests and add some verification for tokenization split
milocress Jan 19, 2024
b7b1d5f
fix formatting
milocress Jan 19, 2024
7764d7c
added docstrings
milocress Jan 19, 2024
224f724
added end-to-end-test with HF dataset
milocress Jan 19, 2024
ca8c02d
fix code style
milocress Jan 22, 2024
cf6664b
renamed file and fixed tests
milocress Jan 22, 2024
31197e7
use chat template diff
milocress Jan 22, 2024
b8bac98
addressed comment
milocress Jan 22, 2024
ef66300
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 24, 2024
5dd4f60
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 24, 2024
e44ceef
fixed type of TokenizedExample
milocress Jan 25, 2024
eca4821
use cast
milocress Jan 25, 2024
4ae1e15
merged
milocress Jan 25, 2024
4a16772
use _ALLOWED_{PROMPT, RESPONSE}_KEYS
milocress Jan 25, 2024
648e4f8
updated tests
milocress Jan 25, 2024
bb54117
fix
milocress Jan 25, 2024
8938d37
fix?
milocress Jan 25, 2024
a7d369d
Merge branch 'main' into milo/use-chat-tokenizers
milocress Jan 25, 2024
e73f0d0
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 25, 2024
4962790
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 25, 2024
128 changes: 124 additions & 4 deletions llmfoundry/data/finetuning/tasks.py
@@ -36,7 +36,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import os
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)

import datasets as hf_datasets
import huggingface_hub as hf_hub
@@ -57,6 +58,35 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
'.downloaded_finetuning'))
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

PromptResponseDict = Dict[str, str]
ChatFormattedDict = Dict[str, List[Dict[str, str]]]
Example = Union[PromptResponseDict, ChatFormattedDict]
ExampleType = Literal['prompt_response', 'chat']
TokenizedExample = Dict[str, List[int]]

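For reference, minimal instances of the two accepted shapes look like the following sketch (the keys follow the aliases above; the string values are only placeholders and are not part of the diff):

# Illustrative only; not part of this diff.
prompt_response_example: PromptResponseDict = {
    'prompt': 'What is the capital of France?',
    'response': 'The capital of France is Paris.',
}

chat_example: ChatFormattedDict = {
    'messages': [
        {'role': 'user', 'content': 'What is the capital of France?'},
        {'role': 'assistant', 'content': 'The capital of France is Paris.'},
    ]
}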

def _get_example_type(example: Example) -> ExampleType:
"""Determines the type of the input example.

Args:
example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair.

Returns:
ExampleType: The type of the input example, which can be either 'chat' for multi-way chat formatted conversation or 'prompt_response' for instruction-response pair.

Raises:
KeyError: If the example type is unknown.
"""
if 'messages' in example:
return 'chat'
elif any([
pr in example
for pr in list(_ALLOWED_PROMPT_KEYS) + list(_ALLOWED_RESPONSE_KEYS)
]):
return 'prompt_response'
else:
raise KeyError(f'Unknown conversation type {example=}')


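A quick sketch of how _get_example_type classifies the two illustrative dicts above; an example carrying neither a 'messages' key nor an allowed prompt/response key raises a KeyError:

# Classification is driven purely by which keys are present.
assert _get_example_type(chat_example) == 'chat'
assert _get_example_type(prompt_response_example) == 'prompt_response'
try:
    _get_example_type({'unexpected_key': 'value'})  # no recognized keys
except KeyError:
    pass  # expected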
def _is_empty_or_nonexistent(dirpath: str) -> bool:
"""Check if a directory is empty or non-existent.
@@ -70,9 +100,70 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


def _tokenize_formatted_example(
example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
def _slice_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
"""Slices the chat example into a formatted prompt and response.

Args:
example (ChatFormattedDict): The chat example containing the messages.
tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template.

Returns:
Tuple[str, str]: The prompt and response as separate strings.

Raises:
ValueError: If the chat example has less than two messages or if the last message is not from the assistant.
KeyError: If a message does not have a role or content.
"""
messages = example['messages']

if len(messages) < 2:
raise ValueError(
f'chat example must have at least two messages. {messages=}')
last_message = messages[-1]
if last_message['role'] != 'assistant':
raise ValueError(
f'last message must be from assistant. {last_message=}')
for message in messages:
if 'role' not in message or 'content' not in message:
raise KeyError(f'message must have role and content. {message=}')

full_conversation = tokenizer.apply_chat_template(messages, tokenize=False)
prompt = tokenizer.apply_chat_template(messages[:-1],
tokenize=False,
add_generation_prompt=True)
if prompt != full_conversation[:len(prompt)]:
raise ValueError(
f'prompt must be the first part of the full conversation. {prompt=}, {full_conversation=}'
)
response = full_conversation[len(prompt):]
if len(response) == 0:
raise ValueError(
f'chat example must have at least one assistant message. {messages=}'
)
return prompt, response


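A usage sketch of the slicing helper. Chat templating is delegated to the tokenizer, so the sketch assumes a tokenizer whose chat_template is set; gpt2 ships without one, so a minimal template is assigned here purely for illustration (chat checkpoints provide their own):

from transformers import AutoTokenizer
from llmfoundry.data.finetuning.tasks import _slice_chat_formatted_example

tokenizer = AutoTokenizer.from_pretrained('gpt2')
# Minimal Jinja chat template, assigned only for this sketch.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>{{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}")

# Reusing chat_example from the earlier sketch.
prompt, response = _slice_chat_formatted_example(chat_example, tokenizer)
# prompt   == '<|user|>What is the capital of France?\n<|assistant|>'
# response == 'The capital of France is Paris.\n'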
def _tokenize_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenizes a chat-formatted example using the provided tokenizer.

Args:
example (ChatFormattedDict): The chat-formatted example to tokenize.
tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.

Returns:
TokenizedExample: The tokenized example.
"""
prompt, response = _slice_chat_formatted_example(example, tokenizer)
return tokenizer(text=prompt, text_target=response)


def _tokenize_prompt_response_formatted_example(
example: PromptResponseDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenize a formatted example and validate expected keys."""
example_keys = set(example.keys())
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
@@ -108,6 +199,35 @@ def _tokenize_formatted_example(
return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
example: Example,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenizes a formatted example using the provided tokenizer.

Args:
example (Example): The input example to be tokenized.
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenization.

Returns:
TokenizedExample: The tokenized example.

Raises:
ValueError: If the example format is unknown.
"""
example_format = _get_example_type(example)

if example_format == 'chat':
chat_example: ChatFormattedDict = cast(ChatFormattedDict, example)
return _tokenize_chat_formatted_example(chat_example, tokenizer)
elif example_format == 'prompt_response':
prompt_response_example: PromptResponseDict = cast(
PromptResponseDict, example)
return _tokenize_prompt_response_formatted_example(
prompt_response_example, tokenizer)
else:
raise ValueError(f'Unknown conversation type {example_format=}')


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

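Taken together, _tokenize_formatted_example now dispatches on the detected example type, so chat-formatted and prompt/response data flow through the same tokenization path. A minimal end-to-end sketch, reusing the example dicts and the template-patched gpt2 tokenizer from the sketches above:

from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

# Both formats share one entry point; tokenizer(text=prompt, text_target=response)
# yields 'input_ids' for the prompt and 'labels' for the response.
for example in (prompt_response_example, chat_example):
    tokenized = _tokenize_formatted_example(example, tokenizer)
    assert 'input_ids' in tokenized
    assert 'labels' in tokenized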
66 changes: 20 additions & 46 deletions tests/data/test_dataloader.py
@@ -23,11 +23,8 @@
from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
_ALLOWED_RESPONSE_KEYS,
DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
_tokenize_formatted_example)
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader,
get_tokens_per_batch_func)
@@ -249,10 +246,12 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool,
break


@pytest.mark.parametrize('use_chat_formatting', [True, False])
@pytest.mark.parametrize('decoder_only_format', [True, False])
@pytest.mark.parametrize('allow_pad_trimming', [True, False])
@pytest.mark.parametrize('packing_ratio', [10.0, None, 'auto'])
def test_finetuning_dataloader(decoder_only_format: bool,
def test_finetuning_dataloader(use_chat_formatting: bool,
decoder_only_format: bool,
allow_pad_trimming: bool,
packing_ratio: Optional[Union[float,
Literal['auto']]]):
@@ -265,13 +264,21 @@ def test_finetuning_dataloader(decoder_only_format: bool,
cfg = {
'name': 'finetuning',
'dataset': {
'hf_name': 'HuggingFaceH4/databricks_dolly_15k',
'split': 'train',
'max_seq_len': max_seq_len,
'decoder_only_format': decoder_only_format,
'allow_pad_trimming': allow_pad_trimming,
'packing_ratio': packing_ratio,
'shuffle': True,
'hf_name':
'iamroot/chat_formatted_examples' if use_chat_formatting else
'HuggingFaceH4/databricks_dolly_15k',
'split':
'train',
'max_seq_len':
max_seq_len,
'decoder_only_format':
decoder_only_format,
'allow_pad_trimming':
allow_pad_trimming,
'packing_ratio':
packing_ratio,
'shuffle':
True,
},
'drop_last': False,
'num_workers': 0,
@@ -417,39 +424,6 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
shutil.rmtree(tiny_dataset_folder_path)


def test_tokenize_example_malformed():
no_keys = {}
no_prompt_key = {'response': 'response'}
no_response_key = {'prompt': 'prompt'}
extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'}
extra_keys_with_response = {'response': 'response', 'extra': 'extra'}
multiple_allowed_response_keys = {
'prompt': 'prompt',
'response': 'response',
'completion': 'completion'
}

malformed_examples = [
no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
extra_keys_with_response, multiple_allowed_response_keys
]

for example in malformed_examples:
with pytest.raises(KeyError):
_tokenize_formatted_example(example, MagicMock())


def test_tokenize_example_well_formed():
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

for prompt_key in _ALLOWED_PROMPT_KEYS:
for response_key in _ALLOWED_RESPONSE_KEYS:
example = {prompt_key: 'prompt', response_key: 'response'}
tokenized_example = _tokenize_formatted_example(example, tokenizer)
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
tokenizer_name = 'gpt2'