diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 42b15e4d6e..4906cea151 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -77,7 +77,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 _ALLOWED_MESSAGES_KEYS = {'messages'}
 _ALLOWED_ROLE_KEYS = {'role'}
 _ALLOWED_CONTENT_KEYS = {'content'}
-_ALLOWED_ROLES = {'user', 'assistant', 'system'}
+_ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
 _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
 DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
     os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
@@ -217,7 +217,7 @@ def slice_out_last_turn(
     if conversation_through_previous_turn != prompt_with_history[:len(
             conversation_through_previous_turn)]:
         raise ValueError(
-            f'The prompt_with_histry must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
+            f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
         )
     if prompt_with_history != full_conversation[:len(prompt_with_history)]:
         raise ValueError(
@@ -624,9 +624,8 @@ def print_registered_tasks(self) -> None:
         log.info('\n'.join(tasks))
 
     def get_preprocessing_fn_from_dict(
-        self,
-        mapping: Dict[str,
-                      str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
+        self, mapping: Dict[str,
+                            str]) -> Callable[[Dict[str, Any]], Example]:
         """Get a preprocessing function from a dictionary.
 
         The dictionary maps column names in the dataset to "prompt" and "response".
@@ -662,7 +661,7 @@ def get_preprocessing_fn_from_str(
         self,
         preprocessor: Optional[str],
         dataset_name: Optional[str] = None
-    ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
+    ) -> Optional[Callable[[Dict[str, Any]], Example]]:
         """Get a preprocessing function from a string.
 
         String can be either a registered function or an import path.
@@ -710,7 +709,7 @@ def get_preprocessing_fn_from_str(
 
     def build_from_hf(
         self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
-        preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
+        preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
         tokenizer: PreTrainedTokenizerBase, target_prompts: str,
         target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
                                                                           Any]
@@ -793,7 +792,8 @@ def build_from_hf(
 
         def dataset_mapper(example: Dict):
             if preprocessing_fn is not None:
-                example = preprocessing_fn(example)
+                return tokenize_formatted_example(preprocessing_fn(example),
+                                                  tokenizer)
             return tokenize_formatted_example(example, tokenizer)
 
         detected_cpu_count = os.cpu_count() or 1
@@ -857,7 +857,7 @@ def build_from_streaming(self, *args: Any,
 
 
 @dataset_constructor.register('tatsu-lab/alpaca')
-def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def alpaca_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Split out prompt/response from text."""
     try:
         prompt, response = inp['text'].split('### Response:')
@@ -869,7 +869,7 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
-def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def dolly_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the text string."""
     PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
     try:
@@ -885,7 +885,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('bigscience/P3')
-def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def p3_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     return {
         'prompt': inp['inputs'] + ':',
@@ -895,7 +895,7 @@ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 # Muennighoff's P3 and flan datasets share a similar convention
 @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
-def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
+def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     try:
         prompt: str = inp['inputs']
@@ -908,3 +908,22 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
     except Exception as e:
         raise UnableToProcessPromptResponseError(inp) from e
     return {'prompt': prompt, 'response': response}
+
+
+@dataset_constructor.register('teknium/OpenHermes-2.5')
+def shareGPT_format_preprocessor(inp: Dict) -> ChatFormattedDict:
+    """Convert from ShareGPT format to our chat format."""
+    role_map = {
+        'human': 'user',
+        'gpt': 'assistant',
+    }
+    try:
+        conversation = inp['conversations']
+        messages: List[Dict[str, str]] = []
+        for message in conversation:
+            role: str = role_map.get(message['from'], message['from'])
+            content: str = message['value']
+            messages.append({'role': role, 'content': content})
+    except Exception as e:
+        raise UnableToProcessPromptResponseError(inp) from e
+    return {'messages': messages}
diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py
index 594e4f778f..e78e76a912 100644
--- a/scripts/data_prep/convert_finetuning_dataset.py
+++ b/scripts/data_prep/convert_finetuning_dataset.py
@@ -269,6 +269,7 @@ def main(args: Namespace) -> None:
             examples_removed = 0
             for sample in tqdm(samples, desc=split_name):
                 formatted_sample = preprocessing_fn(sample)
+                assert isinstance(formatted_sample, dict)
 
                 # Use the _get_example_type utility to confirm that the formatted sample
                 # can be interpreted by the tokenization code
@@ -300,13 +301,15 @@ def main(args: Namespace) -> None:
                         out.write(sample_to_write)
                     else:
                         if example_type == 'prompt_response':
-                            encoded_sample = {
-                                key: formatted_sample[key].encode('utf-8')
-                                for key in ['prompt', 'response']
-                            }
+                            encoded_sample = {}
+                            for key in ['prompt', 'response']:
+                                value = formatted_sample[key]
+                                assert isinstance(value, str)
+                                encoded_sample[key] = value.encode('utf-8')
+                            out.write(encoded_sample)
                         else:
-                            encoded_sample = formatted_sample
-                        out.write(encoded_sample)
+                            out.write(formatted_sample)
+
         if tokenizer is not None and examples_removed > 0:
             warnings.warn(
                 f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index f5c2631fa7..c99ae6baf2 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -1093,3 +1093,87 @@ def test_build_unknown_dataloader():
     tokenizer = MagicMock()
     with pytest.raises(catalogue.RegistryError):
         _ = build_dataloader(cfg, tokenizer, 2)
+
+
+invalid_conversation_params_sharegpt = [
+    'add_invalid_last_chat_message', 'add_invalid_content_type',
+    'add_invalid_role', 'add_not_alternating_roles'
+]
+
+
+@pytest.mark.parametrize(
+    ','.join(invalid_conversation_params_sharegpt),
+    generate_exclusive_test_params(invalid_conversation_params_sharegpt))
+def test_sharegpt_format(tmp_path: pathlib.Path,
+                         add_invalid_last_chat_message: bool,
+                         add_invalid_content_type: bool, add_invalid_role: bool,
+                         add_not_alternating_roles: bool):
+    tokenizer_name = 'mosaicml/mpt-7b'
+    max_seq_len = 2048
+    dataset_size = 5
+    device_batch_size = 5
+    tiny_dataset_folder_path = tmp_path
+    tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl')
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+    tokenizer.add_special_tokens({
+        'pad_token': '<pad>',
+        'bos_token': '<bos>',
+        'eos_token': '<eos>',
+    })
+
+    if dist.get_global_rank() == 0:
+        make_tiny_conversation_ft_dataset(
+            path=tiny_dataset_path,
+            size=dataset_size,
+            add_invalid_last_chat_message=add_invalid_last_chat_message,
+            add_invalid_message_key_quantity=False,
+            add_invalid_content_type=add_invalid_content_type,
+            add_invalid_role=add_invalid_role,
+            add_not_alternating_roles=add_not_alternating_roles,
+            use_messages_format=False,
+        )
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': str(tiny_dataset_folder_path),
+            'preprocessing_fn': 'teknium/OpenHermes-2.5',
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 0,
+        'prefetch_factor': None,
+        'pin_memory': False,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    error_context = contextlib.nullcontext()
+    if add_invalid_last_chat_message:
+        error_context = pytest.raises(InvalidLastChatMessageRoleError,
+                                      match='Invalid last message role:')
+    if add_invalid_content_type:
+        error_context = pytest.raises(InvalidContentTypeError,
+                                      match='Expected content to be')
+    if add_invalid_role:
+        error_context = pytest.raises(InvalidRoleError,
+                                      match='Expected role to be one of')
+
+    if add_not_alternating_roles:
+        error_context = pytest.raises(ConsecutiveRepeatedChatRolesError,
+                                      match='Conversation roles must alternate')
+
+    with error_context:
+        build_finetuning_dataloader(cfg, tokenizer,
+                                    device_batch_size).dataloader
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index a45c4d8f0d..632a79dac9 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -9,6 +9,7 @@
 from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
                                               _ALLOWED_RESPONSE_KEYS,
                                               _slice_chat_formatted_example,
+                                              dataset_constructor,
                                               tokenize_formatted_example)
 from llmfoundry.utils.builders import build_tokenizer
 
@@ -178,34 +179,67 @@ def test_tokenize_instruct_example_well_formed():
 @pytest.mark.parametrize(
     'tokenizer_name',
     ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'])
-def test_multi_turn_chat_slicing(tokenizer_name: str):
-    convo = [
-        {
-            'role': 'system',
-            'content': 'everyone thinks you are so cool'
-        },
-        {
-            'role': 'user',
-            'content': 'hiiii'
-        },
-        {
-            'role': 'assistant',
-            'content': 'yassss'
-        },
-        {
-            'role': 'user',
-            'content': 'HIIIIII!!!'
-        },
-        {
-            'role': 'assistant',
-            'content': 'YASSSSSS'
-        },
-    ]
+@pytest.mark.parametrize('messages_format', [True, False])
+def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
+    if messages_format:
+        convo = [
+            {
+                'role': 'system',
+                'content': 'everyone thinks you are so cool'
+            },
+            {
+                'role': 'user',
+                'content': 'hiiii'
+            },
+            {
+                'role': 'assistant',
+                'content': 'yassss'
+            },
+            {
+                'role': 'user',
+                'content': 'HIIIIII!!!'
+            },
+            {
+                'role': 'assistant',
+                'content': 'YASSSSSS'
+            },
+        ]
+    else:
+        convo = [
+            {
+                'from': 'system',
+                'value': 'everyone thinks you are so cool'
+            },
+            {
+                'from': 'human',
+                'value': 'hiiii'
+            },
+            {
+                'from': 'gpt',
+                'value': 'yassss'
+            },
+            {
+                'from': 'tool',
+                'value': 'HIIIIII!!!'
+            },
+            {
+                'from': 'gpt',
+                'value': 'YASSSSSS'
+            },
+        ]
+        tmp = {'conversations': convo}
+        preprocessor = dataset_constructor.get_preprocessing_fn_from_str(
+            'teknium/OpenHermes-2.5')
+        assert preprocessor is not None
+        convo = preprocessor(tmp)['messages']
+        assert isinstance(convo, list)
+
+    example = {'messages': convo}
 
     tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
 
     templated_prompt_response_turns = _slice_chat_formatted_example(
-        {'messages': convo}, tok)
+        example, tok)
 
     reconstructed_chat = ''
     for prompt, response in templated_prompt_response_turns:
diff --git a/tests/data_utils.py b/tests/data_utils.py
index 3c077b5e71..fd24d4cbbf 100644
--- a/tests/data_utils.py
+++ b/tests/data_utils.py
@@ -6,7 +6,7 @@
 import shutil
 from argparse import Namespace
 from pathlib import Path
-from typing import Optional
+from typing import Dict, List, Optional
 
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
@@ -99,6 +99,7 @@ def make_tiny_conversation_ft_dataset(
     add_invalid_content_type: bool = False,
     add_invalid_role: bool = False,
     add_not_alternating_roles: bool = False,
+    use_messages_format: bool = True,
 ):
     if Path(path).suffix != '.jsonl':
         raise ValueError(f'Path {path} must be a jsonl file.')
@@ -198,6 +199,24 @@ def make_tiny_conversation_ft_dataset(
             }]
         })
 
+    def messages_to_conversation(sample: Dict):
+        assert 'messages' in sample
+        messages = sample['messages']
+
+        role_map = {
+            'user': 'human',
+            'assistant': 'gpt',
+        }
+        conversations: List[Dict[str, str]] = []
+        for message in messages:
+            role: str = role_map.get(message['role'], message['role'])
+            content: str = message['content']
+            conversations.append({'from': role, 'value': content})
+        return {'conversations': conversations}
+
+    if not use_messages_format:
+        samples = [messages_to_conversation(sample) for sample in samples]
+
     os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'w') as _f:
         for sample in samples:
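For reference, a minimal sketch (not part of the patch) of what the newly registered ShareGPT preprocessor produces; the sample conversation below is illustrative, not drawn from OpenHermes-2.5:

    from llmfoundry.data.finetuning.tasks import dataset_constructor

    # Look up the preprocessor registered above under its dataset name,
    # exactly as the new test does.
    preprocessor = dataset_constructor.get_preprocessing_fn_from_str(
        'teknium/OpenHermes-2.5')
    assert preprocessor is not None

    # Hypothetical ShareGPT-style sample: 'from'/'value' keys, 'human'/'gpt' roles.
    sample = {
        'conversations': [
            {'from': 'human', 'value': 'What is 2 + 2?'},
            {'from': 'gpt', 'value': '4'},
        ]
    }

    # The preprocessor maps it to the chat format the finetuning dataloader expects:
    # {'messages': [{'role': 'user', 'content': 'What is 2 + 2?'},
    #               {'role': 'assistant', 'content': '4'}]}
    print(preprocessor(sample))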