Commit

fix conflicting formatting linting guidelines
milocress committed Jan 18, 2024
1 parent 4772ba2 commit 5f5a144
Showing 2 changed files with 156 additions and 13 deletions.
76 changes: 72 additions & 4 deletions llmfoundry/data/finetuning/tasks.py
@@ -1,5 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
# isort:skip_file

"""Includes code for task-specific seq-to-seq data formatting.
@@ -36,7 +37,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import os
import warnings
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

import datasets as hf_datasets
import huggingface_hub as hf_hub
@@ -57,6 +58,23 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
'.downloaded_finetuning'))
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

PromptResponseDict = Dict[str, str]
ChatFormattedDict = Dict[str, List[Dict[str, str]]]
Conversation = Union[PromptResponseDict, ChatFormattedDict]
ConversationType = Literal['prompt_response', 'chat']
TokenizedConversation = Dict[str, List[Union[int, str]]]


def _get_conversation_type(
        conversation_example: Conversation) -> ConversationType:
    # Note: this function does not validate the conversation;
    # it merely determines which validator to use.
    if 'messages' in conversation_example:
        return 'chat'
    elif 'prompt' in conversation_example or 'response' in conversation_example:
        return 'prompt_response'
    else:
        raise KeyError(f'unknown conversation type {conversation_example=}')
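
# --- Editor's note (illustration, not part of the commit): a minimal sketch
# of the two accepted example formats and how the detector above classifies
# them. All field values here are hypothetical.
_example_prompt_response: PromptResponseDict = {
    'prompt': 'What is 2 + 2?',
    'response': '4',
}
_example_chat: ChatFormattedDict = {
    'messages': [
        {'role': 'user', 'content': 'What is 2 + 2?'},
        {'role': 'assistant', 'content': '4'},
    ],
}
assert _get_conversation_type(_example_prompt_response) == 'prompt_response'
assert _get_conversation_type(_example_chat) == 'chat'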


def _is_empty_or_nonexistent(dirpath: str) -> bool:
"""Check if a directory is empty or non-existent.
@@ -70,9 +88,42 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
    return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


-def _tokenize_formatted_example(
-        example: Dict[str, Any],
-        tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
+def _tokenize_chat_formatted_example(
+        example: ChatFormattedDict,
+        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:

    def slice(s: str, sep: str):
        # Split s at the *last* occurrence of sep, keeping sep at the start
        # of the second piece. (It seems like we can reuse this logic, as we
        # likely have this pattern in other places.)
        slices = s.split(sep)
        if len(slices) < 2:
            raise ValueError(f'separator not in string. {sep=}, {s=}')
        a, b = sep.join(slices[:-1]), sep + slices[-1]
        return a, b

    messages = example['messages']

    if len(messages) < 2:
        raise ValueError(
            f'chat example must have at least two messages. {messages=}')
    last_message = messages[-1]
    if last_message['role'] != 'assistant':
        raise ValueError(
            f'last message must be from assistant. {last_message=}')
    for message in messages:
        if 'role' not in message or 'content' not in message:
            raise KeyError(f'message must have role and content. {message=}')

    applied_template = tokenizer.apply_chat_template(messages, tokenize=False)
    prompt, response = slice(applied_template, last_message['content'])
    return {
        'input_ids': tokenizer.tokenize(prompt),
        'labels': tokenizer.tokenize(response)
    }
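
# --- Editor's note (illustration, not part of the commit): how the `slice`
# helper above divides the templated text. The template string below is
# hypothetical; real output depends on the tokenizer's chat template.
_templated = '<|user|>Hi</s><|assistant|>Hello!</s>'
_content = 'Hello!'  # content of the final assistant message
_parts = _templated.split(_content)
_prompt = _content.join(_parts[:-1])  # '<|user|>Hi</s><|assistant|>'
_response = _content + _parts[-1]  # 'Hello!</s>'
assert _prompt + _response == _templated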


def _tokenize_prompt_response_formatted_example(
        example: PromptResponseDict,
        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
    """Tokenize a formatted example and validate expected keys."""
    example_keys = set(example.keys())
    prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
@@ -108,6 +159,23 @@ def _tokenize_formatted_example(
    return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
        example: Conversation,
        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
    example_format = _get_conversation_type(example)

    if example_format == 'chat':
        chat_example: ChatFormattedDict = example  # type: ignore
        return _tokenize_chat_formatted_example(chat_example, tokenizer)
    elif example_format == 'prompt_response':
        prompt_response_example: PromptResponseDict = example  # type: ignore
        return _tokenize_prompt_response_formatted_example(
            prompt_response_example, tokenizer)
    else:
        raise ValueError(f'unknown conversation type {example_format=}')
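
# --- Editor's note (illustration, not part of the commit): a hypothetical
# call to the dispatcher above. The tokenizer name is an assumption; the
# 'chat' branch additionally requires a tokenizer with a chat template.
#
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained('gpt2')
# batch = _tokenize_formatted_example(
#     {'prompt': 'What is 2 + 2?', 'response': '4'}, tok)
# assert 'input_ids' in batch and 'labels' in batch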


class StreamingFinetuningDataset(StreamingDataset):
    """Finetuning dataset with flexible tokenization using StreamingDataset.
93 changes: 84 additions & 9 deletions tests/data/test_dataloader.py
@@ -1,5 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
# isort: skip_file
import contextlib
import os
import pathlib
@@ -9,7 +10,7 @@
from argparse import Namespace
from contextlib import nullcontext as does_not_raise
from pathlib import Path
-from typing import ContextManager, Literal, Optional, Union
+from typing import ContextManager, List, Literal, Optional, Union
from unittest.mock import MagicMock

import pytest
@@ -23,11 +24,10 @@
from llmfoundry import (build_finetuning_dataloader,
                        build_text_denoising_dataloader)
from llmfoundry.data import build_dataloader
-from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
-                                              _ALLOWED_RESPONSE_KEYS,
-                                              DOWNLOADED_FT_DATASETS_DIRPATH,
-                                              SUPPORTED_EXTENSIONS,
-                                              _tokenize_formatted_example)
+from llmfoundry.data.finetuning.tasks import (
+    _ALLOWED_PROMPT_KEYS, _ALLOWED_RESPONSE_KEYS,
+    DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, ChatFormattedDict,
+    PromptResponseDict, _tokenize_formatted_example)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
                                       build_text_dataloader,
                                       get_tokens_per_batch_func)
@@ -428,27 +428,102 @@ def test_tokenize_example_malformed():
        'response': 'response',
        'completion': 'completion'
    }
    no_content = {'messages': [{'role': 'user'}]}
    ends_with_user_role: ChatFormattedDict = {
        'messages': [{
            'role': 'user',
            'content': 'Hello GPT!'
        }, {
            'role': 'assistant',
            'content': 'Hi, User!'
        }, {
            'role': 'user',
            'content': 'user message not followed by an assistant label'
        }]
    }
    no_assistant_message: ChatFormattedDict = {
        'messages': [{
            'role': 'user',
            'content': 'Hello GPT!'
        }, {
            'role': 'user',
            'content': 'user message not followed by an assistant label'
        }]
    }

-    malformed_examples = [
+    malformed_prompt_response_examples = [
        no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
        extra_keys_with_response, multiple_allowed_response_keys
    ]
    malformed_chat_examples = [
        no_content, ends_with_user_role, no_assistant_message
    ]

-    for example in malformed_examples:
+    for example in malformed_prompt_response_examples:
        with pytest.raises(KeyError):
            _tokenize_formatted_example(example, MagicMock())

    my_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
    for example in malformed_chat_examples:
        with pytest.raises(Exception):
            _tokenize_formatted_example(
                example, my_tokenizer
            )  # type: ignore (the typing here is supposed to be malformed)


def test_tokenize_example_well_formed():
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

    for prompt_key in _ALLOWED_PROMPT_KEYS:
        for response_key in _ALLOWED_RESPONSE_KEYS:
-            example = {prompt_key: 'prompt', response_key: 'response'}
+            example: PromptResponseDict = {
+                prompt_key: 'prompt',
+                response_key: 'response'
+            }
            tokenized_example = _tokenize_formatted_example(example, tokenizer)
            assert 'input_ids' in tokenized_example
            assert 'labels' in tokenized_example

    chat_examples: List[ChatFormattedDict] = [
        {
            'messages': [{
                'role': 'user',
                'content': 'Hello, GPT'
            }, {
                'role': 'assistant',
                'content': 'this is my response'
            }]
        },  # prompt/response but in chat format
        {
            'messages': [
                {
                    'role': 'user',
                    'content': 'Hello, GPT'
                },
                {
                    'role': 'assistant',
                    'content': 'this is my response'
                },
                {
                    'role': 'user',
                    'content': 'Nice to hear that.'
                },
                {
                    'role': 'assistant',
                    'content': 'multi-way chat works too!'
                },
            ]
        },  # multi-way chat
    ]

    chat_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
    for chat_example in chat_examples:
        tokenized_example = _tokenize_formatted_example(chat_example,
                                                        chat_tokenizer)
        assert 'input_ids' in tokenized_example
        assert 'labels' in tokenized_example


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
