Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML authored Dec 2, 2023
2 parents cb6864a + b2e4b0e commit 9bc7ce1
Showing 2 changed files with 68 additions and 13 deletions.
39 changes: 30 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
@@ -47,25 +47,46 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 __all__ = ['dataset_constructor']
 
+_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
+_ALLOWED_PROMPT_KEYS = {'prompt'}
+
 
 def _tokenize_formatted_example(
         example: Dict[str, Any],
         tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
-    if ('prompt' not in example) or ('response' not in example):
+    """Tokenize a formatted example and validate expected keys."""
+    example_keys = set(example.keys())
+    prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
+    response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)
+
+    if len(prompt_keys) != 1:
+        raise KeyError(
+            f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\
+            f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}'
+        )
+
+    if len(response_keys) != 1:
         raise KeyError(
-            'Unable to tokenize example because it has not been properly formatted. ' +\
-            '"prompt" and "response" are required keys but at least one was missing ' +\
-            f'from {example=}.'
+            f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\
+            f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_RESPONSE_KEYS=}'
         )
-    if not isinstance(example['prompt'], str):
+
+    prompt_key = prompt_keys.pop()
+    response_key = response_keys.pop()
+    prompt = example[prompt_key]
+    response = example[response_key]
+
+    if not isinstance(prompt, str):
         raise TypeError(
-            f'Unable to tokenize example because "prompt" was not a string. {example=}'
+            f'Unable to tokenize example because {prompt_key} was not a string. {example=}'
        )
-    if not isinstance(example['response'], str):
+
+    if not isinstance(response, str):
         raise TypeError(
-            f'Unable to tokenize example because "response" was not a string. {example=}'
+            f'Unable to tokenize example because {response_key} was not a string. {example=}'
         )
-    return tokenizer(text=example['prompt'], text_target=example['response'])
+
+    return tokenizer(text=prompt, text_target=response)
 
 
 class StreamingFinetuningDataset(StreamingDataset):
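Note: with this change, 'completion' is accepted as an alias for 'response', and an example that supplies more than one allowed key of the same kind is rejected. A minimal usage sketch of the new validation (the gpt2 tokenizer and the example strings are illustrative, not part of this commit):

# Sketch: exercising the new key validation in _tokenize_formatted_example.
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustrative choice

# 'completion' now works as an alias for 'response'.
tokenized = _tokenize_formatted_example(
    {'prompt': 'What is 2+2?', 'completion': '4'}, tokenizer)
assert 'input_ids' in tokenized and 'labels' in tokenized

# Supplying both allowed response keys in one example raises a KeyError.
try:
    _tokenize_formatted_example(
        {'prompt': 'p', 'response': 'r', 'completion': 'c'}, tokenizer)
except KeyError as err:
    print(err)  # ...Please specify exactly one. _ALLOWED_RESPONSE_KEYS=...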
42 changes: 38 additions & 4 deletions tests/data/test_dataloader.py
@@ -21,6 +21,9 @@
 from llmfoundry import (build_finetuning_dataloader,
                         build_text_denoising_dataloader)
 from llmfoundry.data import build_dataloader
+from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
+                                              _ALLOWED_RESPONSE_KEYS,
+                                              _tokenize_formatted_example)
 from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
                                        build_text_dataloader,
                                        get_tokens_per_batch_func)
@@ -355,10 +358,8 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
     if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
         error_context = pytest.raises(ValueError, match='Your dataset')
     if invalid_dataset:
-        error_context = pytest.raises(
-            TypeError,
-            match='Unable to tokenize example because "prompt" was not a string'
-        )
+        error_context = pytest.raises(TypeError,
+                                      match='Unable to tokenize example')
 
     with error_context:
         _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
@@ -367,6 +368,39 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
     shutil.rmtree(tiny_dataset_folder_path)
 
 
+def test_tokenize_example_malformed():
+    no_keys = {}
+    no_prompt_key = {'response': 'response'}
+    no_response_key = {'prompt': 'prompt'}
+    extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'}
+    extra_keys_with_response = {'response': 'response', 'extra': 'extra'}
+    multiple_allowed_response_keys = {
+        'prompt': 'prompt',
+        'response': 'response',
+        'completion': 'completion'
+    }
+
+    malformed_examples = [
+        no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
+        extra_keys_with_response, multiple_allowed_response_keys
+    ]
+
+    for example in malformed_examples:
+        with pytest.raises(KeyError):
+            _tokenize_formatted_example(example, MagicMock())
+
+
+def test_tokenize_example_well_formed():
+    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
+
+    for prompt_key in _ALLOWED_PROMPT_KEYS:
+        for response_key in _ALLOWED_RESPONSE_KEYS:
+            example = {prompt_key: 'prompt', response_key: 'response'}
+            tokenized_example = _tokenize_formatted_example(example, tokenizer)
+            assert 'input_ids' in tokenized_example
+            assert 'labels' in tokenized_example
+
+
 @pytest.mark.parametrize('split', ['train', 'custom', 'data'])
 def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
     tokenizer_name = 'gpt2'
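Note: the malformed cases fail key validation before the tokenizer is ever invoked, which is why a MagicMock can stand in for it; only the well-formed test needs a real tokenizer. A small sketch of that property (illustrative, not part of this commit):

# Sketch: validation raises before the tokenizer is called.
from unittest.mock import MagicMock

from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

mock_tokenizer = MagicMock()
try:
    _tokenize_formatted_example({'extra': 'extra'}, mock_tokenizer)
except KeyError:
    pass
# The KeyError fired during key validation, so the mock was never invoked.
assert not mock_tokenizer.called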
