Skip to content

Commit

Permalink
Fix extra BOS token in front of response for some tokenizers (#1003)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Mar 6, 2024
1 parent f4f6414 commit cf0f5e5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
29 changes: 28 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,29 @@ def _slice_chat_formatted_example(
return prompt, response


def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str,
                               text_target: str) -> TokenizedExample:
    """Tokenizes the prompt and response using the provided tokenizer.

    Some tokenizers automatically prepend a BOS token to every encoded
    sequence, including the labels. A BOS token belongs only at the start of
    the full input sequence, so any leading BOS is stripped from the labels.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
        text (str): The prompt to tokenize.
        text_target (str): The response to tokenize.

    Returns:
        TokenizedExample: The tokenized example, with any automatically added
            BOS token removed from the start of the labels.
    """
    tokenized_sample = tokenizer(text=text, text_target=text_target)

    # Remove the BOS token from the start of the labels if it was
    # automatically added. The emptiness check guards against an IndexError
    # when the response tokenizes to nothing.
    if getattr(tokenizer, 'add_bos_token', False):
        labels = tokenized_sample['labels']
        if (tokenizer.bos_token_id is not None and len(labels) > 0 and
                labels[0] == tokenizer.bos_token_id):
            tokenized_sample['labels'] = labels[1:]

    return tokenized_sample


def _tokenize_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
Expand Down Expand Up @@ -246,7 +269,11 @@ def _tokenize_prompt_response_formatted_example(
f'Unable to tokenize example because {response_key} was not a string. {example=}'
)

return tokenizer(text=prompt, text_target=response)
return _tokenize_with_bos_removal(
tokenizer=tokenizer,
text=prompt,
text_target=response,
)


def tokenize_formatted_example(
Expand Down
27 changes: 27 additions & 0 deletions tests/data/test_template_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,30 @@ def test_tokenize_instruct_example_well_formed():
tokenized_example = tokenize_formatted_example(example, tokenizer)
assert 'input_ids' in tokenized_example
assert 'labels' in tokenized_example


def test_tokenize_no_labels_bos_pr():
    """Labels must never start with a BOS token, whether or not the
    tokenizer automatically prepends one to encoded sequences."""
    example = {'prompt': 'prompt', 'response': 'response'}

    # This tokenizer automatically adds bos tokens
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        'mistralai/Mixtral-8x7B-v0.1')

    # Sanity check: this tokenizer does prepend BOS to encoded sequences.
    assert tokenizer.add_bos_token

    tokenized_example = tokenize_formatted_example(example, tokenizer)

    # The single-token response keeps no BOS in the labels, while the
    # input_ids still begin with BOS.
    assert len(tokenized_example['labels']) == 1
    assert tokenized_example['labels'][0] != tokenizer.bos_token_id
    assert tokenized_example['input_ids'][0] == tokenizer.bos_token_id

    # This tokenizer does not have the add_bos_token attribute
    tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b')

    assert not hasattr(tokenizer, 'add_bos_token')

    tokenized_example = tokenize_formatted_example(example, tokenizer)

    # No BOS anywhere: neither labels nor input_ids start with one.
    assert len(tokenized_example['labels']) == 1
    assert tokenized_example['labels'][0] != tokenizer.bos_token_id
    assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id

0 comments on commit cf0f5e5

Please sign in to comment.