
Commit

fix code style
milocress committed Jan 22, 2024
1 parent 224f724 commit ca8c02d
Showing 3 changed files with 149 additions and 137 deletions.
21 changes: 8 additions & 13 deletions llmfoundry/data/finetuning/tasks.py
@@ -65,8 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:


def _get_example_type(example: Example) -> ExampleType:
"""
Determines the type of the input example.
"""Determines the type of the input example.
Args:
example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair.
@@ -76,7 +75,6 @@ def _get_example_type(example: Example) -> ExampleType:
Raises:
KeyError: If the example type is unknown.
"""
if 'messages' in example:
return 'chat'
@@ -101,20 +99,20 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
def _slice_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
"""
Applies the tokenizer's chat template to the example messages and slices the resulting templated string into a prompt and a completion.
"""Slices the chat example into a formatted prompt and response.
Args:
example (ChatFormattedDict): The chat example containing the messages.
tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template.
Returns:
Tuple[str, str]: The prompt and response as separate strings.
Raises:
ValueError: If the chat example has less than two messages or if the last message is not from the assistant.
KeyError: If a message does not have a role or content.
"""

def slice(s: str, sep: str):
# it seems like we can reuse this logic, as we likely have this pattern in other places.
slices = s.split(sep)
@@ -144,16 +142,14 @@ def slice(s: str, sep: str):
def _tokenize_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""
Tokenizes a chat-formatted example using the provided tokenizer.
"""Tokenizes a chat-formatted example using the provided tokenizer.
Args:
example (ChatFormattedDict): The chat-formatted example to tokenize.
tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
Returns:
TokenizedExample: The tokenized example.
"""
prompt, response = _slice_chat_formatted_example(example, tokenizer)
return tokenizer(text=prompt, text_target=response)
@@ -200,8 +196,7 @@ def _tokenize_prompt_response_formatted_example(
def _tokenize_formatted_example(
example: Example,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""
Tokenizes a formatted example using the provided tokenizer.
"""Tokenizes a formatted example using the provided tokenizer.
Args:
example (Example): The input example to be tokenized.
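For context, the functions touched in this file implement a two-step chat tokenization: the templated conversation string is sliced into a prompt and a response (_slice_chat_formatted_example), and the two halves are then passed to the tokenizer as text and text_target (_tokenize_chat_formatted_example). A minimal sketch of that flow, using a toy template and a made-up assistant separator rather than the repository's real chat template:

from typing import Dict, List, Tuple

# Hypothetical delimiter marking the assistant turn; real chat templates
# (e.g. ChatML) use model-specific tokens.
ASSISTANT_SEP = '<|assistant|>'

def render_template(messages: List[Dict[str, str]]) -> str:
    """Render a conversation into one templated string (toy template)."""
    return ''.join(f"<|{m['role']}|>{m['content']}" for m in messages)

def slice_chat(messages: List[Dict[str, str]]) -> Tuple[str, str]:
    """Split the templated string into (prompt, response) at the last assistant turn."""
    if len(messages) < 2 or messages[-1]['role'] != 'assistant':
        raise ValueError('chat example must end with an assistant message')
    templated = render_template(messages)
    # Everything up to and including the final assistant separator is the
    # prompt; the remainder is the response the model should learn to emit.
    head, sep, response = templated.rpartition(ASSISTANT_SEP)
    return head + sep, response

prompt, response = slice_chat([
    {'role': 'user', 'content': 'Hello, GPT'},
    {'role': 'assistant', 'content': 'this is my response'},
])
print(prompt)    # <|user|>Hello, GPT<|assistant|>
print(response)  # this is my response

A real example would finish with tokenizer(text=prompt, text_target=response), which is exactly what _tokenize_chat_formatted_example does with the sliced strings.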
139 changes: 139 additions & 0 deletions tests/data/test_chat_tokenization.py
@@ -0,0 +1,139 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock

import pytest
import transformers

from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
_ALLOWED_RESPONSE_KEYS,
_slice_chat_formatted_example,
_tokenize_formatted_example)
from llmfoundry.utils.builders import build_tokenizer


def test_tokenize_chat_example_malformed():
no_content = {'messages': [{'role': 'user'}]}
ends_with_user_role = {
'messages': [{
'role': 'user',
'content': 'Hello GPT!'
}, {
'role': 'assistant',
'content': 'Hi, User!'
}, {
'role': 'user',
'content': 'user message not followed by an assistant label'
}]
}
no_assistant_message = {
'messages': [{
'role': 'user',
'content': 'Hello GPT!'
}, {
'role': 'user',
'content': 'user message not followed by an assistant label'
}]
}
malformed_chat_examples = [
no_content, ends_with_user_role, no_assistant_message
]
my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {})
for example in malformed_chat_examples:
with pytest.raises(Exception):
_tokenize_formatted_example(
example, my_tokenizer
) # type: ignore (the typing here is supposed to be malformed)


def test_tokenize_chat_example_well_formed():
chat_examples = [
{
'messages': [{
'role': 'user',
'content': 'Hello, GPT'
}, {
'role': 'assistant',
'content': 'this is my response'
}]
}, # prompt/response but in chat format
{
'messages': [
{
'role': 'user',
'content': 'Hello, GPT'
},
{
'role': 'assistant',
'content': 'this is my response'
},
{
'role': 'user',
'content': 'Nice to hear that.'
},
{
'role': 'assistant',
'content': 'multi-way chat works too!'
},
]
}, # multi-way chat
]

expected = [
{
'prompt': 'TODO: fix',
'response': 'TODO: fix'
},
{
'prompt': 'TODO: fix',
'response': 'TODO: fix'
},
]

chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {})
assert len(expected) == len(
chat_examples) # if we add a new example, zip shouldn't fail silently
for chat_example, expected_stringification in zip(chat_examples, expected):
prompt, response = _slice_chat_formatted_example(
chat_example, chat_tokenizer)
tokenized_example = _tokenize_formatted_example(chat_example,
chat_tokenizer)
assert prompt == expected_stringification['prompt']
assert response == expected_stringification['response']
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example


def test_tokenize_instruct_example_malformed():
no_keys = {}
no_prompt_key = {'response': 'response'}
no_response_key = {'prompt': 'prompt'}
extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'}
extra_keys_with_response = {'response': 'response', 'extra': 'extra'}
multiple_allowed_response_keys = {
'prompt': 'prompt',
'response': 'response',
'completion': 'completion'
}

malformed_prompt_response_examples = [
no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
extra_keys_with_response, multiple_allowed_response_keys
]

for example in malformed_prompt_response_examples:
with pytest.raises(KeyError):
_tokenize_formatted_example(example, MagicMock())


def test_tokenize_instruct_example_well_formed():
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

for prompt_key in _ALLOWED_PROMPT_KEYS:
for response_key in _ALLOWED_RESPONSE_KEYS:

example = {prompt_key: 'prompt', response_key: 'response'}
tokenized_example = _tokenize_formatted_example(example, tokenizer)
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example
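The expected entries in test_tokenize_chat_example_well_formed are still 'TODO: fix' placeholders. One way to derive the reference strings, assuming the checkpoint ships a chat template (this sketch uses the standard Hugging Face apply_chat_template API and is not part of the commit):

import transformers

# Any chat-capable tokenizer works; the checkpoint name mirrors the test above.
tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b-8k-chat')

messages = [
    {'role': 'user', 'content': 'Hello, GPT'},
    {'role': 'assistant', 'content': 'this is my response'},
]

# Render the conversation with and without the final assistant message; the
# response is the suffix contributed by that last message, assuming the
# generation-prompt rendering is a prefix of the full rendering.
full = tokenizer.apply_chat_template(messages, tokenize=False)
prompt = tokenizer.apply_chat_template(messages[:-1],
                                       tokenize=False,
                                       add_generation_prompt=True)
response = full[len(prompt):]
print(repr(prompt), repr(response))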
126 changes: 2 additions & 124 deletions tests/data/test_dataloader.py
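(The chat and instruct tokenization tests removed below are the ones added in tests/data/test_chat_tokenization.py above, nearly verbatim; the chat tests switch from the 'mosaicml/mpt-7b-chat' tokenizer to 'mosaicml/mpt-7b-8k-chat', and the well-formed chat test now checks exact expected strings instead of substring containment.)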
@@ -23,12 +23,8 @@
from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data import build_dataloader
-from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
-                                              _ALLOWED_RESPONSE_KEYS,
-                                              DOWNLOADED_FT_DATASETS_DIRPATH,
-                                              SUPPORTED_EXTENSIONS,
-                                              _slice_chat_formatted_example,
-                                              _tokenize_formatted_example)
+from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
+                                              SUPPORTED_EXTENSIONS)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader,
get_tokens_per_batch_func)
@@ -428,124 +424,6 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
shutil.rmtree(tiny_dataset_folder_path)


def test_tokenize_instruct_example_malformed():
no_keys = {}
no_prompt_key = {'response': 'response'}
no_response_key = {'prompt': 'prompt'}
extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'}
extra_keys_with_response = {'response': 'response', 'extra': 'extra'}
multiple_allowed_response_keys = {
'prompt': 'prompt',
'response': 'response',
'completion': 'completion'
}

malformed_prompt_response_examples = [
no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
extra_keys_with_response, multiple_allowed_response_keys
]

for example in malformed_prompt_response_examples:
with pytest.raises(KeyError):
_tokenize_formatted_example(example, MagicMock())


def test_tokenize_chat_example_malformed():
no_content = {'messages': [{'role': 'user'}]}
ends_with_user_role = {
'messages': [{
'role': 'user',
'content': 'Hello GPT!'
}, {
'role': 'assistant',
'content': 'Hi, User!'
}, {
'role': 'user',
'content': 'user message not followed by an assistant label'
}]
}
no_assistant_message = {
'messages': [{
'role': 'user',
'content': 'Hello GPT!'
}, {
'role': 'user',
'content': 'user message not followed by an assistant label'
}]
}
malformed_chat_examples = [
no_content, ends_with_user_role, no_assistant_message
]
my_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {})
for example in malformed_chat_examples:
with pytest.raises(Exception):
_tokenize_formatted_example(
example, my_tokenizer
) # type: ignore (the typing here is supposed to be malformed)


def test_tokenize_instruct_example_well_formed():
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

for prompt_key in _ALLOWED_PROMPT_KEYS:
for response_key in _ALLOWED_RESPONSE_KEYS:

example = {prompt_key: 'prompt', response_key: 'response'}
tokenized_example = _tokenize_formatted_example(example, tokenizer)
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example


def test_tokenize_chat_example_well_formed():
chat_examples = [
{
'messages': [{
'role': 'user',
'content': 'Hello, GPT'
}, {
'role': 'assistant',
'content': 'this is my response'
}]
}, # prompt/response but in chat format
{
'messages': [
{
'role': 'user',
'content': 'Hello, GPT'
},
{
'role': 'assistant',
'content': 'this is my response'
},
{
'role': 'user',
'content': 'Nice to hear that.'
},
{
'role': 'assistant',
'content': 'multi-way chat works too!'
},
]
}, # multi-way chat
]

chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {})
for chat_example in chat_examples:
last_message = chat_example['messages'][-1]['content']
earlier_messages = [
msg['content'] for msg in chat_example['messages'][:-1]
]
prompt, response = _slice_chat_formatted_example(
chat_example, chat_tokenizer)
tokenized_example = _tokenize_formatted_example(chat_example,
chat_tokenizer)
assert last_message in response
for earlier_message in earlier_messages:
assert earlier_message in prompt
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example


@pytest.mark.parametrize('split', ['train', 'custom', 'data'])
def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
tokenizer_name = 'gpt2'
