Pad tiktoken vocab so that additional_special_tokens works (#695)
dakinggg authored Oct 26, 2023
1 parent c60657b commit 6f59738
Showing 2 changed files with 64 additions and 8 deletions.
33 changes: 31 additions & 2 deletions llmfoundry/tokenizers/tiktoken.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

+import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
@@ -26,7 +27,7 @@ def __init__(self,
                 eos_token: Optional[str] = '<|endoftext|>',
                 bos_token: Optional[str] = '<|endoftext|>',
                 pad_token: Optional[str] = None,
-                **kwargs: Dict[str, Any]):
+                **kwargs: Any):
        """Constructor creates a tiktoken tokenizer to use as the underlying.

        tokenizer.
@@ -90,7 +91,17 @@ def is_fast(self) -> bool:
        return False

    def get_vocab(self) -> Dict[str, int]:
-        """Returns vocab as a dict."""
+        """Returns vocab as a dict.
+
+        Note: This function does not work properly due to difference in assumptions between tiktoken and Hugging Face tokenizers.
+        Most uses do not need to use get_vocab, so this is not a priority to fix.
+        """
+        warnings.warn(
+            'get_vocab does not work properly with TiktokenTokenizerWrapper. Please do not rely on it being perfectly correct.'
+            +
+            ' It will be called once init just to get the size of the vocab inside the base class.'
+        )
+
        vocab = {}
        for i in range(self.vocab_size):
            try:
@@ -101,6 +112,24 @@ def get_vocab(self) -> Dict[str, int]:
            except KeyError:
                pass

+        # As far as I can tell, we don't require get_vocab to completely work,
+        # but when using additional_special_tokens, Hugging Face determines the next
+        # token index to add with len(self.get_vocab()) so we need the _size_ of this dictionary to be correct.
+        extra_id_index = 0
+        candidate_extra_id = f'<extra_id_{extra_id_index}>'
+        indices_to_fill_in = {i for i in range(self.vocab_size)} - set(
+            vocab.values())
+
+        # Add enough indices to make get_vocab() the right length
+        for index_to_add in indices_to_fill_in:
+            # Make sure we don't overwrite a token that already exists
+            while candidate_extra_id in vocab:
+                extra_id_index += 1
+                candidate_extra_id = f'<extra_id_{extra_id_index}>'
+
+            # Get an index to add and add the item
+            vocab[candidate_extra_id] = index_to_add
+
        return vocab

    def _tokenize(self, text: str) -> List[int]:
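The comment block in get_vocab above is the core of the fix: Hugging Face slow tokenizers take the id for a newly added token from len(self.get_vocab()), so this dict only needs the right size, not perfectly faithful contents. Below is a toy sketch of that interaction, separate from the commit itself; next_added_token_id is a simplified stand-in for the transformers behavior the comment describes, and the example vocab is made up.

# Toy illustration (not part of this commit) of why the vocab gets padded.
# next_added_token_id stands in for the Hugging Face behavior described in
# get_vocab above: a newly added special token gets the id len(get_vocab()).

def next_added_token_id(vocab: dict) -> int:
    return len(vocab)

vocab_size = 5

# Unpadded: several tiktoken ids decode to the same string, so they collapse
# into one dict entry and the "next" id collides with an id tiktoken already uses.
unpadded = {'a': 0, 'b': 1, '\ufffd': 2}  # ids 3 and 4 also decoded to '\ufffd'
assert next_added_token_id(unpadded) == 3  # collides with existing id 3

# Padded, mirroring the loop in get_vocab: fill every missing id with a unique
# '<extra_id_N>' placeholder so the dict ends up with vocab_size entries.
padded = dict(unpadded)
extra_id_index = 0
for index_to_add in set(range(vocab_size)) - set(padded.values()):
    while f'<extra_id_{extra_id_index}>' in padded:
        extra_id_index += 1
    padded[f'<extra_id_{extra_id_index}>'] = index_to_add

assert len(padded) == vocab_size
assert next_added_token_id(padded) == 5  # first id tiktoken does not use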
39 changes: 33 additions & 6 deletions tests/test_tiktoken.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import pathlib
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

import pytest
import transformers
@@ -49,15 +49,18 @@ def get_tokenizers_for_testing(
    encoding_name: Optional[str],
    tmp_path: pathlib.Path,
    add_bos_token: bool = False,
-    add_eos_token: bool = False
+    add_eos_token: bool = False,
+    additional_special_tokens: Optional[List[str]] = None,
) -> Tuple[TiktokenTokenizerWrapper, TiktokenTokenizerWrapper, 'Encoding']:
    tiktoken = pytest.importorskip('tiktoken')

    # Construction
-    wrapped_tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
-                                                 encoding_name=encoding_name,
-                                                 add_bos_token=add_bos_token,
-                                                 add_eos_token=add_eos_token)
+    wrapped_tokenizer = TiktokenTokenizerWrapper(
+        model_name=model_name,
+        encoding_name=encoding_name,
+        add_bos_token=add_bos_token,
+        add_eos_token=add_eos_token,
+        additional_special_tokens=additional_special_tokens)
    if model_name is not None:
        original_tokenizer = tiktoken.encoding_for_model(model_name)
    else:
@@ -176,6 +179,10 @@ def test_tiktoken_vocab(model_name: Optional[str], encoding_name: Optional[str],

    didnt_match = []
    for key, value in wrapped_vocab.items():
+        # Skip checking the extra ids we pad the vocab with
+        if key.startswith('<extra_id') and key.endswith('>'):
+            continue
+
        if original_tokenizer.encode(key, allowed_special='all') == [value]:
            continue
        else:
@@ -232,3 +239,23 @@ def test_tiktoken_encode_plus(model_name: Optional[str],
    encoded_special_mask = encoded_outputs.special_tokens_mask
    assert encoded_special_mask[0] == 1
    assert encoded_special_mask[-1] == 1
+
+
+@pytest.mark.parametrize('model_name,encoding_name',
+                         MODEL_ENCODING_NAME_PARAMETRIZATION)
+def test_additional_special_tokens(model_name: Optional[str],
+                                   encoding_name: Optional[str],
+                                   tmp_path: pathlib.Path):
+    special_token_to_add = '<|im_start|>'
+    wrapped_tokenizer, _, _ = get_tokenizers_for_testing(
+        model_name,
+        encoding_name,
+        tmp_path,
+        add_bos_token=False,
+        add_eos_token=False,
+        additional_special_tokens=[special_token_to_add])
+    encoded_outputs = wrapped_tokenizer(special_token_to_add +
+                                        ' hello')['input_ids']
+
+    assert encoded_outputs[0] == wrapped_tokenizer.vocab_size
+    assert len(encoded_outputs) == 2
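For context, a usage sketch of the behavior this test pins down. It is not part of the commit: the 'gpt-4' model name, the direct import path taken from the changed file above, and the default add_bos_token/add_eos_token of False are assumptions mirroring the test parameters.

# Usage sketch mirroring test_additional_special_tokens; model name, import
# path, and default bos/eos settings are assumptions, not part of the diff.
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper

tokenizer = TiktokenTokenizerWrapper(
    model_name='gpt-4',
    additional_special_tokens=['<|im_start|>'])

ids = tokenizer('<|im_start|> hello')['input_ids']

# Because get_vocab() is padded to vocab_size entries, the added special token
# is assigned the next free id, i.e. vocab_size, and encodes as a single token.
assert ids[0] == tokenizer.vocab_size
assert len(ids) == 2  # '<|im_start|>' plus the single token for ' hello'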
