Fix/extend re replacement seq (dottxt-ai#948)
This PR is an extension of
dottxt-ai#763, further broadening
the `re_replacement_seq` regex.

The new [NorwAI models](https://huggingface.co/NorwAI) use a tokenizer
whose vocabulary contains the token `�.`, which triggers the same error
as described in the previous issue
dottxt-ai#762.

This PR extends the fix from
dottxt-ai#763 to handle this
case, adds a unit test covering a range of such rare tokens, and adds a
comment explaining why the prefix and suffix are needed in the regex.
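
As a quick illustration (not part of the commit itself), the snippet below compares the old and new patterns on tokens of the kind covered by the unit test further down; only the extended pattern accepts the NorwAI-style `�.` token:

```python
import re

old_re_replacement_seq = re.compile(r"^▁*�+$")      # pattern before this PR
new_re_replacement_seq = re.compile(r"^▁*�+\.*$")   # pattern after this PR

for token in ["�", "▁�", "�.", "▁▁�.."]:
    # Only the first two tokens match the old pattern; all four match the new one.
    print(repr(token), bool(old_re_replacement_seq.match(token)), bool(new_re_replacement_seq.match(token)))
```

Tokens such as `�.` are therefore recognised as replacement sequences instead of falling through to the error path from dottxt-ai#762.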
saattrupdan authored Jun 12, 2024
1 parent 11af6ce commit a987159
Showing 2 changed files with 33 additions and 1 deletion.
5 changes: 4 additions & 1 deletion outlines/fsm/regex.py
@@ -784,7 +784,10 @@ def create_fsm_index_end_to_end(


re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^▁*�+$")

# The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*"
# suffix is required to handle the NorwAI tokenizer.
re_replacement_seq = re.compile(r"^▁*�+\.*$")


# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
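For context on why a non-matching token is a problem, here is a hedged sketch (illustrative only, not the actual `outlines` source) of the kind of check that `re_replacement_seq` guards: a token containing U+FFFD that is not a recognised replacement sequence cannot be mapped back to bytes, which is the error reported in dottxt-ai#762.

```python
import re

re_replacement_seq = re.compile(r"^▁*�+\.*$")


def check_token(token: str) -> None:
    """Illustrative check only; the real reduced_vocabulary logic differs in detail."""
    if "\ufffd" in token and not re_replacement_seq.match(token):
        # Hypothetical error, modelled on the failure reported in dottxt-ai#762.
        raise RuntimeError(f"Cannot convert token `{token}` to bytes.")


check_token("▁�")  # handled since dottxt-ai#763
check_token("�.")  # handled by this PR; would have raised with the old pattern
```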
29 changes: 29 additions & 0 deletions tests/fsm/test_regex.py
@@ -18,6 +18,7 @@
    reduced_vocabulary,
    walk_fsm,
)
from outlines.integrations.utils import adapt_tokenizer
from outlines.models.transformers import TransformerTokenizer


@@ -686,3 +687,31 @@ def test_numba_leading_null_byte_unicode_type_sane(input_key):
    d = numba.typed.typeddict.Dict.empty(numba.types.unicode_type, numba.int64)
    d["一"] = 10 # \xe4\xb8\x80
    str(d) # assert successfully interprets


@pytest.mark.parametrize(
    "rare_token",
    [
        "�",
        "��",
        "�.",
        "�..",
        "▁�",
        "▁▁�",
        "▁�.",
        "▁▁�.",
        "▁▁�..",
    ],
)
def test_reduced_vocabulary_with_rare_tokens(rare_token):
    """Assert reduced_vocabulary works with rare tokens.

    See [1] and [2] for context.

    [1]: https://github.com/outlines-dev/outlines/pull/763
    [2]: https://github.com/outlines-dev/outlines/pull/948
    """
    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    tokenizer = adapt_tokenizer(tokenizer=tokenizer)
    tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1
    reduced_vocabulary(tokenizer)
