Commit
Fix slow GemmaTokenizer and improve SPM slow -> fast conversion process (#32191)

* Remove user-defined tokens which can be obtained through merges

* Remove debug line

* formatting

* Refactor spm slow -> fast converter

* revert unnecessary refactor

* set comprehension

* remove test files

* Use `vocab_scores`

* Always replace spiece underline with space in decode

* we no longer need token filtering

* Add save fast load slow unit test

* Remove tokenizers version check

* Remove duplicate code

* Make `<start_of_turn>` and `<end_of_turn>` special tokens

* Bias merge priority with length if score is the same

* Add unit test for merge priority

* CI
xenova authored Jul 30, 2024
1 parent 026a173 commit 6e2d04e
Showing 3 changed files with 120 additions and 151 deletions.
234 changes: 84 additions & 150 deletions src/transformers/convert_slow_tokenizer.py
@@ -53,6 +53,25 @@ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
return prepend_scheme


def generate_merges(vocab, vocab_scores):
reverse = vocab_scores is not None
vocab_scores = dict(vocab_scores) if reverse else vocab

merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)

merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
return merges
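
For reference, a minimal usage sketch of the new `generate_merges` helper (not part of the diff; it assumes a transformers install that already contains this commit, so the import below resolves). It illustrates the length bias this PR adds: when scores tie, merges built from longer pieces are ranked first.

from transformers.convert_slow_tokenizer import generate_merges

# Toy vocab: "abc" and "bc" share the score -5.0, so the extra length keys in
# the sort decide their relative merge priority.
vocab = {"a": 0, "b": 1, "c": 2, "ab": 3, "abc": 4, "bc": 5}
vocab_scores = [("a", -1.0), ("b", -2.0), ("c", -3.0), ("ab", -4.0), ("abc", -5.0), ("bc", -5.0)]

print(generate_merges(vocab, vocab_scores))
# [('a', 'b'), ('ab', 'c'), ('a', 'bc'), ('b', 'c')] -- on the -5.0 tie, the
# merge with the longer left piece ("ab", "c") is ranked ahead of ("a", "bc").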


class SentencePieceExtractor:
"""
Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
@@ -73,24 +92,8 @@ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
sp = self.sp
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

if vocab_scores is not None:
vocab_scores, reverse = dict(vocab_scores), True
else:
vocab_scores, reverse = vocab, False

# Merges
merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)
merges = generate_merges(vocab, vocab_scores)

merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
return vocab, merges


@@ -107,24 +110,7 @@ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
# "<0x09>" is the bytefallback for `\t`
vocab["\t"] = vocab.get("<0x09>")

if vocab_scores is not None:
vocab_scores, reverse = dict(vocab_scores), True
else:
vocab_scores, reverse = vocab, False

# Merges
merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)

merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
merges = generate_merges(vocab, vocab_scores)
return vocab, merges


@@ -544,6 +530,10 @@ def converted(self) -> Tokenizer:


class SpmConverter(Converter):
handle_byte_fallback = False
SpmExtractor = SentencePieceExtractor
special_tokens = {}

def __init__(self, *args):
requires_backends(self, "protobuf")

@@ -557,14 +547,13 @@ def __init__(self, *args):
m.ParseFromString(f.read())
self.proto = m

if self.proto.trainer_spec.byte_fallback:
if not getattr(self, "handle_byte_fallback", None):
warnings.warn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)
if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
warnings.warn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)

def vocab(self, proto):
return [(piece.piece, piece.score) for piece in proto.pieces]
@@ -575,26 +564,72 @@ def unk_id(self, proto):
def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
unk_id = self.unk_id(proto)

if model_type == 1:
tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
tokenizer = Tokenizer(
Unigram(
vocab_scores,
unk_id=self.unk_id(proto),
byte_fallback=self.handle_byte_fallback,
)
)

elif model_type == 2:
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
_, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
byte_fallback=self.handle_byte_fallback,
dropout=None,
)
)

else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)

# control tokens are special
# user defined symbols are not
# both user and control tokens are AddedTokens
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
spm_added_tokens = [
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
for id, p in enumerate(proto.pieces)
if p.type in [3, 4]
]
tokens_to_add = [
AddedToken(token, normalized=False, special=special)
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
]

if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
# individual tokens would repeatedly rebuild a trie, which can be slow.
is_last_special = None
tokens = []
for token in tokens_to_add:
is_special = token.special
if is_last_special is None or is_last_special == is_special:
tokens.append(token)
else:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
tokens = [token]
is_last_special = is_special
if tokens:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)

return tokenizer
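
The batching above exists only to avoid rebuilding the tokenizer's trie once per token. As an illustration (not part of the diff), the same grouping can be expressed with `itertools.groupby`, assuming `tokens_to_add` is a list of `tokenizers.AddedToken`:

from itertools import groupby

def add_in_batches(tokenizer, tokens_to_add):
    # Group consecutive tokens sharing the same `special` flag and add each run
    # with a single call, so the trie is rebuilt once per run instead of once
    # per token.
    for is_special, run in groupby(tokens_to_add, key=lambda t: t.special):
        batch = list(run)
        if is_special:
            tokenizer.add_special_tokens(batch)
        else:
            tokenizer.add_tokens(batch)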

def normalizer(self, proto):
@@ -622,40 +657,6 @@ def decoder(self, replacement, add_prefix_space):
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)

# control tokens are special
# user defined symbols are not
# both user and control tokens are AddedTokens
# Add user defined symbols (type == 4) from sentnecepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)

tokens_to_add = {
id: AddedToken(token, normalized=False, special=special)
for id, token, special in [
(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]
]
}
tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])]
if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
# individual tokens would repeatedly rebuild a trie, which can be slow.
is_last_special = None
tokens = []
for token in tokens_to_add:
is_special = token.special
if is_last_special is None or is_last_special == is_special:
tokens.append(token)
else:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
tokens = [token]
is_last_special = is_special
if tokens:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
# Tokenizer assemble
normalizer = self.normalizer(self.proto)
if normalizer is not None:
@@ -1283,6 +1284,9 @@ def post_processor(self):

class GemmaConvert(SpmConverter):
handle_byte_fallback = True
SpmExtractor = GemmaSentencePieceExtractor
# start and end of turn tokens must be marked as special
special_tokens = {"<start_of_turn>", "<end_of_turn>"}

""""
split_by_unicode_script: true
@@ -1327,45 +1331,6 @@ def decoder(self, replacement, add_prefix_space):
]
)

def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
if model_type == 1:
import tokenizers

if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
else:
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))

elif model_type == 2:
_, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}

tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
byte_fallback=True,
dropout=None,
)
)
tokenizer.add_special_tokens(
[
AddedToken("<pad>", normalized=False, special=True),
AddedToken("<eos>", normalized=False, special=True),
AddedToken("<bos>", normalized=False, special=True),
AddedToken("<unk>", normalized=False, special=True),
]
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)
return tokenizer


class LlamaConverter(SpmConverter):
handle_byte_fallback = True
@@ -1393,37 +1358,6 @@ def decoder(self, replacement, add_prefix_space):
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)

def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
if model_type == 1:
import tokenizers

if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
else:
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))

elif model_type == 2:
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
)
tokenizer.add_special_tokens(
[
AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
]
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)

return tokenizer

def normalizer(self, proto):
if getattr(self.original_tokenizer, "legacy", True):
sequence = []
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/tokenization_gemma.py
@@ -198,7 +198,7 @@ def _decode(
else:
sub_texts = "".join(sub_texts)

return sub_texts
return sub_texts.replace(SPIECE_UNDERLINE, " ")

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
35 changes: 35 additions & 0 deletions tests/models/gemma/test_tokenization_gemma.py
@@ -222,6 +222,17 @@ def test_fast_special_tokens(self):
self.tokenizer.add_eos_token = False
self.rust_tokenizer.add_eos_token = False

def test_fast_merge_priority(self):
slow_tokenizer = self.tokenizer
fast_tokenizer = self.rust_tokenizer
text = " "
target = [168, 153]
slow = slow_tokenizer.encode(text, add_special_tokens=False)
assert slow == target

fast = fast_tokenizer.encode(text, add_special_tokens=False)
assert fast == target

@unittest.skip(reason="Not super important and always failing. Let's skip it")
@slow
def test_conversion(self):
@@ -442,6 +453,30 @@ def test_tokenization_for_chat(self):
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)

def test_save_fast_load_slow(self):
# Ensure that we can save a fast tokenizer and load it as a slow tokenizer
slow_tokenizer = self.tokenizer
text = "a "
target_encoded = [2, 235250, 139]
slow = slow_tokenizer.encode(text, add_special_tokens=True)
assert slow == target_encoded

slow_decoded = slow_tokenizer.decode(slow, skip_special_tokens=True)
assert slow_decoded == text

with tempfile.TemporaryDirectory() as dirname:
# Save fast tokenizer
self.rust_tokenizer.save_pretrained(dirname)

# Load slow tokenizer with fast files present in the directory
slow_tokenizer_from_fast = GemmaTokenizer.from_pretrained(dirname)

slow_from_fast = slow_tokenizer_from_fast.encode(text, add_special_tokens=True)
assert slow_from_fast == target_encoded

slow_from_fast_decoded = slow_tokenizer_from_fast.decode(slow, skip_special_tokens=True)
assert slow_from_fast_decoded == text


@require_sentencepiece
@require_tokenizers
