Commit

address comments

milocress committed Jan 19, 2024
1 parent fda8ab3 commit ef7c78e
Showing 2 changed files with 5 additions and 8 deletions.
11 changes: 4 additions & 7 deletions llmfoundry/data/finetuning/tasks.py
@@ -60,11 +60,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 PromptResponseDict = Dict[str, str]
 ChatFormattedDict = Dict[str, List[Dict[str, str]]]
 Conversation = Union[PromptResponseDict, ChatFormattedDict]
-ConversationType = Literal['prompt_response', 'chat']
+ExampleType = Literal['prompt_response', 'chat']
 TokenizedConversation = Dict[str, List[Union[int, str]]]
 
 
-def _get_conversation_type(conversation_example: Conversation):
+def _get_example_type(conversation_example: Conversation) -> ExampleType:
     # note: this function does not validate the conversation types,
     # it merely determines which validator to use.
     if 'messages' in conversation_example:
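
For reference, here is a minimal, self-contained sketch of how the renamed _get_example_type helper could resolve an example's type from its keys. Only the 'messages' check appears in this diff; the 'prompt'/'response' branch and the error case are assumptions added for illustration.

from typing import Dict, List, Literal, Union

PromptResponseDict = Dict[str, str]
ChatFormattedDict = Dict[str, List[Dict[str, str]]]
Conversation = Union[PromptResponseDict, ChatFormattedDict]
ExampleType = Literal['prompt_response', 'chat']


def _get_example_type(conversation_example: Conversation) -> ExampleType:
    # note: this function does not validate the conversation types,
    # it merely determines which validator to use.
    if 'messages' in conversation_example:
        return 'chat'
    elif 'prompt' in conversation_example and 'response' in conversation_example:
        return 'prompt_response'  # assumed branch, not shown in this diff
    else:
        raise KeyError('Unknown conversation type')
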
@@ -114,10 +114,7 @@ def slice(s: str, sep: str):
 
     applied_template = tokenizer.apply_chat_template(messages, tokenize=False)
     prompt, response = slice(applied_template, last_message['content'])
-    return {
-        'input_ids': tokenizer.tokenize(prompt),
-        'labels': tokenizer.tokenize(response)
-    }
+    return tokenizer(text=prompt, text_target=response)
 
 
 def _tokenize_prompt_response_formatted_example(
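
The tokenization change above matters because tokenizer.tokenize(...) returns a list of string tokens, whereas calling the tokenizer directly returns integer input_ids, and passing text_target additionally produces a labels field for the response. A small standalone illustration, using 'gpt2' purely as a stand-in checkpoint (not one referenced by this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Old approach: string tokens only, no ids/labels structure.
string_tokens = tokenizer.tokenize('Hello world')
print(string_tokens)          # e.g. ['Hello', 'Ġworld']

# New approach: a BatchEncoding whose 'labels' come from text_target.
encoded = tokenizer(text='Hello', text_target=' world')
print(encoded['input_ids'])   # integer ids for the prompt
print(encoded['labels'])      # integer ids for the response
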
@@ -161,7 +158,7 @@ def _tokenize_prompt_response_formatted_example(
 def _tokenize_formatted_example(
         example: Conversation,
         tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
-    example_format = _get_conversation_type(example)
+    example_format = _get_example_type(example)
 
     if example_format == 'chat':
         chat_example: ChatFormattedDict = example # type: ignore
2 changes: 1 addition & 1 deletion tests/data/test_dataloader.py
@@ -514,7 +514,7 @@ def test_tokenize_example_well_formed():
         }, # multi-way chat
     ]
 
-    chat_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
+    chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {})
     for chat_example in chat_examples:
         tokenized_example = _tokenize_formatted_example(chat_example,
                                                         chat_tokenizer)
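
For context on what the test feeds in: a chat-formatted example is a dict with a 'messages' list of role/content dicts. The snippet below is illustrative only; the actual messages and assertions in test_tokenize_example_well_formed are not shown in this diff.

chat_example = {
    'messages': [
        {'role': 'user', 'content': 'Hello, GPT'},
        {'role': 'assistant', 'content': 'Hi there!'},
    ],
}  # illustrative input; the real test builds several such examples

# With the repo's helpers, tokenization would then look roughly like:
# chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-chat', {})
# tokenized_example = _tokenize_formatted_example(chat_example, chat_tokenizer)
# assert 'input_ids' in tokenized_example and 'labels' in tokenized_example
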
