Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for chat formatted finetuning input data. #884

Merged
merged 28 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5f5a144
fix conflicting formatting linting guidelines
milocress Jan 18, 2024
3ef2b20
used older union operator for legacy support
milocress Jan 18, 2024
21aa9e3
did the same thing in another place
milocress Jan 18, 2024
2148405
isort ignore specific lines
milocress Jan 18, 2024
b660dbc
fixes
milocress Jan 18, 2024
fda8ab3
isort do not skip line
milocress Jan 18, 2024
ef7c78e
address comments
milocress Jan 19, 2024
e0bd660
renamed some more things
milocress Jan 19, 2024
6e73e04
split tests and add some verification for tokenization split
milocress Jan 19, 2024
b7b1d5f
fix formatting
milocress Jan 19, 2024
7764d7c
added docstrings
milocress Jan 19, 2024
224f724
added end-to-end-test with HF dataset
milocress Jan 19, 2024
ca8c02d
fix code style
milocress Jan 22, 2024
cf6664b
renamed file and fixed tests
milocress Jan 22, 2024
31197e7
use chat template diff
milocress Jan 22, 2024
b8bac98
addressed comment
milocress Jan 22, 2024
ef66300
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 24, 2024
5dd4f60
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 24, 2024
e44ceef
fixed type of TokenizedExample
milocress Jan 25, 2024
eca4821
use cast
milocress Jan 25, 2024
4ae1e15
merged
milocress Jan 25, 2024
4a16772
use _ALLOWED_{PROMPT, RESPONSE}_KEYS
milocress Jan 25, 2024
648e4f8
updated tests
milocress Jan 25, 2024
bb54117
fix
milocress Jan 25, 2024
8938d37
fix?
milocress Jan 25, 2024
a7d369d
Merge branch 'main' into milo/use-chat-tokenizers
milocress Jan 25, 2024
e73f0d0
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 25, 2024
4962790
Update llmfoundry/data/finetuning/tasks.py
milocress Jan 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
# isort:skip_file

"""Includes code for task-specific seq-to-seq data formatting.

Expand Down Expand Up @@ -36,7 +37,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import os
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import datasets as hf_datasets
import huggingface_hub as hf_hub
Expand All @@ -57,6 +58,23 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
'.downloaded_finetuning'))
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

PromptResponseDict = Dict[str, str]
ChatFormattedDict = Dict[str, List[Dict[str, str]]]
Conversation = Union[PromptResponseDict, ChatFormattedDict]
ConversationType = Literal['prompt_response', 'chat']
TokenizedConversation = Dict[str, List[Union[int, str]]]


def _get_conversation_type(conversation_example: Conversation):
milocress marked this conversation as resolved.
Show resolved Hide resolved
# note: this function does not validate the conversation types,
# it merely determines which validator to use.
if 'messages' in conversation_example:
return 'chat'
elif 'prompt' in conversation_example or 'response' in conversation_example:
return 'prompt_response'
else:
raise KeyError(f'unknown conversation type {conversation_example=}')


def _is_empty_or_nonexistent(dirpath: str) -> bool:
"""Check if a directory is empty or non-existent.
Expand All @@ -70,9 +88,42 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


def _tokenize_formatted_example(
example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
def _tokenize_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:

def slice(s: str, sep: str):
# it seems like we can reuse this logic, as we likely have this pattern in other places.
slices = s.split(sep)
if len(slices) < 2:
raise ValueError(f'separator not in string. {sep=}, {s=}')
a, b = sep.join(slices[:-1]), sep + slices[-1]
return a, b

messages = example['messages']

if len(messages) < 2:
raise ValueError(
f'chat example must have at least two messages. {messages=}')
last_message = messages[-1]
if last_message['role'] != 'assistant':
raise ValueError(
f'last message must be from assistant. {last_message=}')
for message in messages:
if 'role' not in message or 'content' not in message:
raise KeyError(f'message must have role and content. {message=}')

applied_template = tokenizer.apply_chat_template(messages, tokenize=False)
prompt, response = slice(applied_template, last_message['content'])
return {
'input_ids': tokenizer.tokenize(prompt),
'labels': tokenizer.tokenize(response)
milocress marked this conversation as resolved.
Show resolved Hide resolved
}


def _tokenize_prompt_response_formatted_example(
example: PromptResponseDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
"""Tokenize a formatted example and validate expected keys."""
example_keys = set(example.keys())
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
Expand Down Expand Up @@ -108,6 +159,23 @@ def _tokenize_formatted_example(
return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
        example: Conversation,
        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
    """Tokenize a finetuning example in either supported format.

    Classifies the example with ``_get_conversation_type`` and forwards it to
    the matching format-specific tokenization function.

    Args:
        example (Conversation): A chat-formatted or prompt/response-formatted
            example.
        tokenizer (PreTrainedTokenizerBase): Tokenizer passed through to the
            format-specific tokenization function.

    Returns:
        TokenizedConversation: Whatever the selected format-specific
            tokenization function returns.

    Raises:
        KeyError: Propagated from ``_get_conversation_type`` when the example
            has none of the recognized keys.
        ValueError: If the detected conversation type has no handler here.
    """
    example_format = _get_conversation_type(example)

    if example_format == 'chat':
        chat_example: ChatFormattedDict = example  # type: ignore
        return _tokenize_chat_formatted_example(chat_example, tokenizer)
    elif example_format == 'prompt_response':
        prompt_response_example: PromptResponseDict = example  # type: ignore
        return _tokenize_prompt_response_formatted_example(
            prompt_response_example, tokenizer)
    else:
        # Defensive: only reachable if a new conversation type is added to
        # _get_conversation_type without a matching branch above.
        raise ValueError(f'unknown conversation type {example_format=}')


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

Expand Down
93 changes: 84 additions & 9 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
# isort: skip_file
import contextlib
import os
import pathlib
Expand All @@ -9,7 +10,7 @@
from argparse import Namespace
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import ContextManager, Literal, Optional, Union
from typing import ContextManager, List, Literal, Optional, Union
from unittest.mock import MagicMock

import pytest
Expand All @@ -23,11 +24,10 @@
from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
_ALLOWED_RESPONSE_KEYS,
DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
_tokenize_formatted_example)
from llmfoundry.data.finetuning.tasks import (
_ALLOWED_PROMPT_KEYS, _ALLOWED_RESPONSE_KEYS,
DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, ChatFormattedDict,
PromptResponseDict, _tokenize_formatted_example)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader,
get_tokens_per_batch_func)
Expand Down Expand Up @@ -428,27 +428,102 @@ def test_tokenize_example_malformed():
'response': 'response',
milocress marked this conversation as resolved.
Show resolved Hide resolved
'completion': 'completion'
}
no_content = {'messages': [{'role': 'user'}]}
ends_with_user_role: ChatFormattedDict = {
milocress marked this conversation as resolved.
Show resolved Hide resolved
'messages': [{
'role': 'user',
'content': 'Hello GPT!'
}, {
'role': 'assistant',
'content': 'Hi, User!'
}, {
'role': 'user',
'content': 'user message not followed by an assistant label'
}]
}
no_assistant_message: ChatFormattedDict = {
'messages': [{
'role': 'user',
'content': 'Hello GPT!'
}, {
'role': 'user',
'content': 'user message not followed by an assistant label'
}]
}

malformed_examples = [
malformed_prompt_response_examples = [
no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
extra_keys_with_response, multiple_allowed_response_keys
]
malformed_chat_examples = [
milocress marked this conversation as resolved.
Show resolved Hide resolved
no_content, ends_with_user_role, no_assistant_message
]

for example in malformed_examples:
for example in malformed_prompt_response_examples:
with pytest.raises(KeyError):
_tokenize_formatted_example(example, MagicMock())

my_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
for example in malformed_chat_examples:
with pytest.raises(Exception):
_tokenize_formatted_example(
example, my_tokenizer
) # type: ignore (the typing here is supposed to be malformed)


def test_tokenize_example_well_formed():
    """Check that well-formed examples of both formats tokenize."""
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

    # Every allowed prompt key / response key pairing must be accepted.
    for prompt_key in _ALLOWED_PROMPT_KEYS:
        for response_key in _ALLOWED_RESPONSE_KEYS:
            example: PromptResponseDict = {
                prompt_key: 'prompt',
                response_key: 'response'
            }
            tokenized_example = _tokenize_formatted_example(example, tokenizer)
            assert 'input_ids' in tokenized_example
            assert 'labels' in tokenized_example

    chat_examples: List[ChatFormattedDict] = [
        {
            'messages': [{
                'role': 'user',
                'content': 'Hello, GPT'
            }, {
                'role': 'assistant',
                'content': 'this is my response'
            }]
        },  # prompt/response but in chat format
        {
            'messages': [
                {
                    'role': 'user',
                    'content': 'Hello, GPT'
                },
                {
                    'role': 'assistant',
                    'content': 'this is my response'
                },
                {
                    'role': 'user',
                    'content': 'Nice to hear that.'
                },
                {
                    'role': 'assistant',
                    'content': 'multi-way chat works too!'
                },
            ]
        },  # multi-way chat
    ]

    # Chat examples are tokenized with a tokenizer built from a chat model.
    chat_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
    for chat_example in chat_examples:
        tokenized_example = _tokenize_formatted_example(chat_example,
                                                        chat_tokenizer)
        assert 'input_ids' in tokenized_example
        assert 'labels' in tokenized_example


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
Expand Down
Loading