From ef7c78e17a970b3742acfd5d07c41a8c62db7e15 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 19 Jan 2024 15:19:40 +0000 Subject: [PATCH] address comments --- llmfoundry/data/finetuning/tasks.py | 11 ++++------- tests/data/test_dataloader.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index c873527968..c201e3edb9 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -60,11 +60,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: PromptResponseDict = Dict[str, str] ChatFormattedDict = Dict[str, List[Dict[str, str]]] Conversation = Union[PromptResponseDict, ChatFormattedDict] -ConversationType = Literal['prompt_response', 'chat'] +ExampleType = Literal['prompt_response', 'chat'] TokenizedConversation = Dict[str, List[Union[int, str]]] -def _get_conversation_type(conversation_example: Conversation): +def _get_example_type(conversation_example: Conversation) -> ExampleType: # note: this function does not validate the conversation types, # it merely determines which validator to use. if 'messages' in conversation_example: @@ -114,10 +114,7 @@ def slice(s: str, sep: str): applied_template = tokenizer.apply_chat_template(messages, tokenize=False) prompt, response = slice(applied_template, last_message['content']) - return { - 'input_ids': tokenizer.tokenize(prompt), - 'labels': tokenizer.tokenize(response) - } + return tokenizer(text=prompt, text_target=response) def _tokenize_prompt_response_formatted_example( @@ -161,7 +158,7 @@ def _tokenize_prompt_response_formatted_example( def _tokenize_formatted_example( example: Conversation, tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation: - example_format = _get_conversation_type(example) + example_format = _get_example_type(example) if example_format == 'chat': chat_example: ChatFormattedDict = example # type: ignore diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 41fed4ae7c..9e983c7e5a 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -514,7 +514,7 @@ def test_tokenize_example_well_formed(): }, # multi-way chat ] - chat_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {}) + chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {}) for chat_example in chat_examples: tokenized_example = _tokenize_formatted_example(chat_example, chat_tokenizer)