Add sharegpt tests plus death by pyright
alextrott16 committed Apr 11, 2024
1 parent 2ecfcba commit cf4992a
Showing 5 changed files with 192 additions and 50 deletions.
26 changes: 12 additions & 14 deletions llmfoundry/data/finetuning/tasks.py
@@ -614,9 +614,8 @@ def print_registered_tasks(self) -> None:
         log.info('\n'.join(tasks))
 
     def get_preprocessing_fn_from_dict(
-            self,
-            mapping: Dict[str,
-                          str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
+        self, mapping: Dict[str,
+                            str]) -> Callable[[Dict[str, Any]], Example]:
         """Get a preprocessing function from a dictionary.
         The dictionary maps column names in the dataset to "prompt" and "response".
@@ -652,7 +651,7 @@ def get_preprocessing_fn_from_str(
         self,
         preprocessor: Optional[str],
         dataset_name: Optional[str] = None
-    ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
+    ) -> Optional[Callable[[Dict[str, Any]], Example]]:
         """Get a preprocessing function from a string.
         String can be either a registered function or an import path.
@@ -700,7 +699,7 @@ def get_preprocessing_fn_from_str(
 
     def build_from_hf(
         self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
-        preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
+        preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
         tokenizer: PreTrainedTokenizerBase, target_prompts: str,
         target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
                                                                           Any]
@@ -783,7 +782,8 @@ def build_from_hf(
 
         def dataset_mapper(example: Dict):
             if preprocessing_fn is not None:
-                example = preprocessing_fn(example)
+                return tokenize_formatted_example(preprocessing_fn(example),
+                                                  tokenizer)
             return tokenize_formatted_example(example, tokenizer)
 
         detected_cpu_count = os.cpu_count() or 1
@@ -847,7 +847,7 @@ def build_from_streaming(self, *args: Any,
 
 
 @dataset_constructor.register('tatsu-lab/alpaca')
-def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def alpaca_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Split out prompt/response from text."""
     try:
         prompt, response = inp['text'].split('### Response:')
@@ -859,7 +859,7 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
-def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def dolly_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the text string."""
     PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
     try:
@@ -875,7 +875,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('bigscience/P3')
-def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def p3_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     return {
         'prompt': inp['inputs'] + ':',
@@ -885,7 +885,7 @@ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 # Muennighoff's P3 and flan datasets share a similar convention
 @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
-def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
+def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     try:
         prompt: str = inp['inputs']
@@ -901,19 +901,17 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('teknium/OpenHermes-2.5')
-def shareGPT_format_preprocessor(inp: Dict) -> Dict[str, List[Dict[str, str]]]:
+def shareGPT_format_preprocessor(inp: Dict) -> ChatFormattedDict:
     """Convert from ShareGPT format to our chat format."""
     role_map = {
         'human': 'user',
         'gpt': 'assistant',
-        'system': 'system',
-        'tool': 'tool'
     }
     try:
         conversation = inp['conversations']
         messages: List[Dict[str, str]] = []
         for message in conversation:
-            role: str = role_map[message['from']]
+            role: str = role_map.get(message['from'], message['from'])
             content: str = message['value']
             messages.append({'role': role, 'content': content})
     except Exception as e:
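For illustration, the registered preprocessor above turns ShareGPT-style records into the chat format the tokenization code consumes; with .get(), roles outside the map (e.g. 'system', 'tool') now pass through unchanged. The sketch below restates that conversion as a standalone function. The sample record is made up, and the Example / PromptResponseDict / ChatFormattedDict annotations in the diff are assumed to be typed-dict aliases defined elsewhere in tasks.py.

from typing import Dict, List

# Standalone restatement of the ShareGPT -> chat conversion shown above.
def sharegpt_to_chat(inp: Dict) -> Dict[str, List[Dict[str, str]]]:
    role_map = {'human': 'user', 'gpt': 'assistant'}
    messages = [{
        'role': role_map.get(m['from'], m['from']),
        'content': m['value'],
    } for m in inp['conversations']]
    return {'messages': messages}

# Hypothetical ShareGPT-style record (illustrative only).
record = {
    'conversations': [
        {'from': 'system', 'value': 'You are helpful.'},
        {'from': 'human', 'value': 'hi'},
        {'from': 'gpt', 'value': 'hello!'},
    ]
}
print(sharegpt_to_chat(record))
# {'messages': [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', ...}]}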
15 changes: 9 additions & 6 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -269,6 +269,7 @@ def main(args: Namespace) -> None:
            examples_removed = 0
            for sample in tqdm(samples, desc=split_name):
                formatted_sample = preprocessing_fn(sample)
+                assert isinstance(formatted_sample, dict)
 
                # Use the _get_example_type utility to confirm that the formatted sample
                # can be interpreted by the tokenization code
@@ -300,13 +301,15 @@ def main(args: Namespace) -> None:
                    out.write(sample_to_write)
                else:
                    if example_type == 'prompt_response':
-                        encoded_sample = {
-                            key: formatted_sample[key].encode('utf-8')
-                            for key in ['prompt', 'response']
-                        }
+                        encoded_sample = {}
+                        for key in ['prompt', 'response']:
+                            value = formatted_sample[key]
+                            assert isinstance(value, str)
+                            encoded_sample[key] = value.encode('utf-8')
                        out.write(encoded_sample)
                    else:
-                        out.write(formatted_sample)
+                        encoded_sample = formatted_sample
+                        out.write(encoded_sample)
 
    if tokenizer is not None and examples_removed > 0:
        warnings.warn(
            f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
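The rewritten loop above unrolls the old dict comprehension so each value can be narrowed to str before encoding, which is what satisfies pyright. A minimal standalone sketch of the same pattern (the sample data is illustrative, not from the dataset):

# Illustrative prompt/response sample; real samples come from preprocessing_fn.
formatted_sample = {'prompt': 'What is 2 + 2?', 'response': '4'}

encoded_sample = {}
for key in ['prompt', 'response']:
    value = formatted_sample[key]
    # The isinstance check narrows value to str for the type checker and
    # guards against malformed rows before encoding to bytes.
    assert isinstance(value, str)
    encoded_sample[key] = value.encode('utf-8')

print(encoded_sample)  # {'prompt': b'What is 2 + 2?', 'response': b'4'}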
85 changes: 85 additions & 0 deletions tests/data/test_dataloader.py
@@ -1093,3 +1093,88 @@ def test_build_unknown_dataloader():
     tokenizer = MagicMock()
     with pytest.raises(catalogue.RegistryError):
         _ = build_dataloader(cfg, tokenizer, 2)
+
+
+@pytest.mark.parametrize(
+    ','.join(invalid_conversation_params),
+    generate_exclusive_test_params(invalid_conversation_params))
+def test_sharegpt_format(tmp_path: pathlib.Path,
+                         add_invalid_last_chat_message: bool,
+                         add_invalid_message_key_quantity: bool,
+                         add_invalid_content_type: bool, add_invalid_role: bool,
+                         add_not_alternating_roles: bool):
+    tokenizer_name = 'databricks/dbrx-instruct'
+    max_seq_len = 2048
+    dataset_size = 5
+    device_batch_size = 5
+    tiny_dataset_folder_path = tmp_path
+    tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl')
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={
+            'model_max_length': max_seq_len,
+            'trust_remote_code': True
+        },
+    )
+    tokenizer.add_special_tokens({
+        'pad_token': '<pad>',
+        'bos_token': '<bos>',
+        'eos_token': '<eos>',
+    })
+
+    if dist.get_global_rank() == 0:
+        make_tiny_conversation_ft_dataset(
+            path=tiny_dataset_path,
+            size=dataset_size,
+            add_invalid_last_chat_message=add_invalid_last_chat_message,
+            add_invalid_message_key_quantity=add_invalid_message_key_quantity,
+            add_invalid_content_type=add_invalid_content_type,
+            add_invalid_role=add_invalid_role,
+            add_not_alternating_roles=add_not_alternating_roles,
+            use_messages_format=False,
+        )
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': str(tiny_dataset_folder_path),
+            'preprocessing_fn': 'teknium/OpenHermes-2.5',
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 0,
+        'prefetch_factor': None,
+        'pin_memory': False,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    error_context = contextlib.nullcontext()
+    if add_invalid_last_chat_message:
+        error_context = pytest.raises(InvalidLastChatMessageRoleError,
+                                      match='Invalid last message role:')
+    if add_invalid_message_key_quantity:
+        error_context = pytest.raises(IncorrectMessageKeyQuantityError,
+                                      match='Expected 2 keys in message')
+    if add_invalid_content_type:
+        error_context = pytest.raises(InvalidContentTypeError,
+                                      match='Expected content to be')
+    if add_invalid_role:
+        error_context = pytest.raises(InvalidRoleError,
+                                      match='Expected role to be one of')
+
+    if add_not_alternating_roles:
+        error_context = pytest.raises(ConsecutiveRepeatedChatRolesError,
+                                      match='Conversation roles must alternate')
+
+    with error_context:
+        build_finetuning_dataloader(cfg, tokenizer,
+                                    device_batch_size).dataloader
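The new test points the finetuning dataloader at a tiny ShareGPT-format dataset and reuses the 'teknium/OpenHermes-2.5' preprocessor to convert it at load time. With use_messages_format=False, each JSONL row uses 'conversations' with 'from'/'value' keys; a hypothetical row of the kind the fixture writes might be produced like this (contents illustrative only, not the actual fixture data):

import json

# Hypothetical ShareGPT-format row written to train.jsonl.
row = {
    'conversations': [
        {'from': 'human', 'value': 'hello'},
        {'from': 'gpt', 'value': 'hi there'},
    ]
}
with open('train.jsonl', 'w') as f:
    f.write(json.dumps(row) + '\n')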
95 changes: 66 additions & 29 deletions tests/data/test_template_tokenization.py
@@ -9,6 +9,7 @@
 from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
                                               _ALLOWED_RESPONSE_KEYS,
                                               _slice_chat_formatted_example,
+                                              dataset_constructor,
                                               tokenize_formatted_example)
 from llmfoundry.utils.builders import build_tokenizer
 
@@ -175,37 +176,73 @@ def test_tokenize_instruct_example_well_formed():
     assert 'labels' in tokenized_example['turns'][0]
 
 
-@pytest.mark.parametrize(
-    'tokenizer_name',
-    ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'])
-def test_multi_turn_chat_slicing(tokenizer_name: str):
-    convo = [
-        {
-            'role': 'system',
-            'content': 'everyone thinks you are so cool'
-        },
-        {
-            'role': 'user',
-            'content': 'hiiii'
-        },
-        {
-            'role': 'assistant',
-            'content': 'yassss'
-        },
-        {
-            'role': 'user',
-            'content': 'HIIIIII!!!'
-        },
-        {
-            'role': 'assistant',
-            'content': 'YASSSSSS'
-        },
-    ]
-
-    tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
+@pytest.mark.parametrize('tokenizer_name', [
+    'databricks/dbrx-instruct', 'EleutherAI/gpt-neox-20b',
+    'HuggingFaceH4/zephyr-7b-beta', 't5-base'
+])
+@pytest.mark.parametrize('messages_format', [True, False])
+def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
+    if messages_format:
+        convo = [
+            {
+                'role': 'system',
+                'content': 'everyone thinks you are so cool'
+            },
+            {
+                'role': 'user',
+                'content': 'hiiii'
+            },
+            {
+                'role': 'assistant',
+                'content': 'yassss'
+            },
+            {
+                'role': 'user',
+                'content': 'HIIIIII!!!'
+            },
+            {
+                'role': 'assistant',
+                'content': 'YASSSSSS'
+            },
+        ]
+    else:
+        convo = [
+            {
+                'from': 'system',
+                'value': 'everyone thinks you are so cool'
+            },
+            {
+                'from': 'human',
+                'value': 'hiiii'
+            },
+            {
+                'from': 'gpt',
+                'value': 'yassss'
+            },
+            {
+                'from': 'tool',
+                'value': 'HIIIIII!!!'
+            },
+            {
+                'from': 'gpt',
+                'value': 'YASSSSSS'
+            },
+        ]
+        tmp = {'conversations': convo}
+        preprocessor = dataset_constructor.get_preprocessing_fn_from_str(
+            'teknium/OpenHermes-2.5')
+        assert preprocessor is not None
+        convo = preprocessor(tmp)['messages']
+        assert isinstance(convo, list)
+
+    example = {'messages': convo}
+
+    tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name,
+                                                     trust_remote_code='dbrx'
+                                                     in tokenizer_name)
 
     templated_prompt_response_turns = _slice_chat_formatted_example(
-        {'messages': convo}, tok)
+        example, tok)
 
     reconstructed_chat = ''
     for prompt, response in templated_prompt_response_turns:
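The remainder of the test (truncated above) stitches the sliced turns back together. A rough sketch of that check, under the assumption that concatenating the (prompt, response) pairs should reproduce the tokenizer's fully templated conversation; the exact assertion in the test may differ:

import transformers
from llmfoundry.data.finetuning.tasks import _slice_chat_formatted_example

tok = transformers.AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
convo = [
    {'role': 'user', 'content': 'hiiii'},
    {'role': 'assistant', 'content': 'yassss'},
]

# Full conversation rendered by the tokenizer's chat template.
full_chat = tok.apply_chat_template(convo, tokenize=False)

# Concatenating the sliced turns should rebuild the same string.
reconstructed = ''
for prompt, response in _slice_chat_formatted_example({'messages': convo}, tok):
    reconstructed += prompt + response
assert reconstructed == full_chat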
21 changes: 20 additions & 1 deletion tests/data_utils.py
Expand Up @@ -6,7 +6,7 @@
import shutil
from argparse import Namespace
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional

from omegaconf import DictConfig
from omegaconf import OmegaConf as om
@@ -99,6 +99,7 @@ def make_tiny_conversation_ft_dataset(
     add_invalid_content_type: bool = False,
     add_invalid_role: bool = False,
     add_not_alternating_roles: bool = False,
+    use_messages_format: bool = True,
 ):
     if Path(path).suffix != '.jsonl':
         raise ValueError(f'Path {path} must be a jsonl file.')
@@ -198,6 +199,24 @@ def make_tiny_conversation_ft_dataset(
             }]
         })
 
+    def messages_to_conversation(sample: Dict):
+        assert 'messages' in sample
+        messages = sample['messages']
+
+        role_map = {
+            'user': 'human',
+            'assistant': 'gpt',
+        }
+        conversations: List[Dict[str, str]] = []
+        for message in messages:
+            role: str = role_map.get(message['role'], message['role'])
+            content: str = message['content']
+            conversations.append({'from': role, 'value': content})
+        return {'conversations': conversations}
+
+    if not use_messages_format:
+        samples = [messages_to_conversation(sample) for sample in samples]
+
     os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'w') as _f:
         for sample in samples:
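The new helper simply inverts the role mapping used by the ShareGPT preprocessor, so messages-format fixtures can be rewritten as 'conversations' rows. A small usage sketch of the same logic outside the fixture (sample content is made up):

from typing import Dict, List

def messages_to_conversation(sample: Dict) -> Dict[str, List[Dict[str, str]]]:
    # Inverse of the ShareGPT role map: 'user' -> 'human', 'assistant' -> 'gpt'.
    role_map = {'user': 'human', 'assistant': 'gpt'}
    conversations = [{
        'from': role_map.get(m['role'], m['role']),
        'value': m['content'],
    } for m in sample['messages']]
    return {'conversations': conversations}

sample = {
    'messages': [
        {'role': 'user', 'content': 'hello'},
        {'role': 'assistant', 'content': 'hi there'},
    ]
}
print(messages_to_conversation(sample))
# {'conversations': [{'from': 'human', 'value': 'hello'},
#                    {'from': 'gpt', 'value': 'hi there'}]}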
