Commit 0b457ec
update RegexGuide to conform with outlines-core
lapp0 committed Oct 10, 2024
1 parent d7569ef commit 0b457ec
Showing 7 changed files with 29 additions and 24 deletions.
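In outlines-core, a RegexGuide is built with the from_regex classmethod instead of the constructor, so every call site in this commit changes the same way. A minimal sketch of the migration (the tokenizer value is illustrative):

    # before: guides were built by calling the constructor directly
    guide = RegexGuide("[1-9]", tokenizer)

    # after: guides are built through the from_regex classmethod
    guide = RegexGuide.from_regex("[1-9]", tokenizer)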
2 changes: 1 addition & 1 deletion benchmarks/bench_json_schema.py
@@ -77,4 +77,4 @@ def time_json_schema_to_regex(self, schema_name):
     @cache_disabled()
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
-        RegexGuide(regex, self.tokenizer)
+        RegexGuide.from_regex(regex, self.tokenizer)
4 changes: 2 additions & 2 deletions benchmarks/bench_regex_guide.py
@@ -25,7 +25,7 @@ def setup(self, pattern_name):
 
     @cache_disabled()
     def time_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
 
 
 class MemoryRegexGuideBenchmark:
@@ -37,4 +37,4 @@ def setup(self, pattern_name):
 
     @cache_disabled()
     def peakmem_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
27 changes: 16 additions & 11 deletions outlines/fsm/guide.py
@@ -74,8 +74,8 @@ def copy(self):
 
 
 @cache()
-def create_states_mapping(regex_string, tokenizer):
-    return uncached_create_states_mapping(regex_string, tokenizer)
+def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
+    return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
 
 
 class RegexGuide(CoreRegexGuide):
@@ -84,15 +84,20 @@ class RegexGuide(CoreRegexGuide):
     CoreRegexGuide with outlines cache
     """
 
-    def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
-        (
-            self.states_to_token_maps,
-            self.empty_token_ids,
-            fsm_finals,
-        ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
-        self._cache_state_to_token_tensor()
+    @classmethod
+    def from_regex(
+        cls,
+        regex_string: str,
+        tokenizer,
+        _create_states_mapping=cached_create_states_mapping,
+        **kwargs,
+    ):
+        return super().from_regex(
+            regex_string,
+            tokenizer,
+            _create_states_mapping=_create_states_mapping,
+            **kwargs,
+        )
 
 
 CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])
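The hunk above relies on dependency injection through a default argument: CoreRegexGuide.from_regex accepts a _create_states_mapping callable, and the outlines subclass pins a cached default so callers get memoization for free. A self-contained sketch of that pattern, with stand-in names (BaseGuide, build_map, and the lru_cache-based cache are illustrative, not the real outlines-core internals):

    from functools import lru_cache

    def build_map(regex_string, vocab):
        # Stand-in for the expensive regex -> states-to-tokens compilation.
        return {0: {token_id: 1 for token_id, _ in enumerate(vocab)}}

    # Memoized variant; arguments must be hashable (str and tuple here).
    cached_build_map = lru_cache(maxsize=None)(build_map)

    class BaseGuide:
        @classmethod
        def from_regex(cls, regex_string, vocab, _create_states_mapping=build_map):
            guide = cls.__new__(cls)
            guide.states_to_token_maps = _create_states_mapping(regex_string, vocab)
            return guide

    class CachedGuide(BaseGuide):
        @classmethod
        def from_regex(cls, regex_string, vocab,
                       _create_states_mapping=cached_build_map, **kwargs):
            # Same construction path as the parent; only the default callable
            # changes, so the expensive step is cached transparently.
            return super().from_regex(
                regex_string, vocab,
                _create_states_mapping=_create_states_mapping, **kwargs
            )

    # The second call with identical arguments hits the cache.
    CachedGuide.from_regex("[1-9]", ("1", "2"))
    CachedGuide.from_regex("[1-9]", ("1", "2"))

This is also why the diff adds *args/**kwargs to cached_create_states_mapping: the cached wrapper must forward whatever extra arguments the core constructor threads through.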
2 changes: 1 addition & 1 deletion outlines/processors/structured.py
@@ -149,7 +149,7 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
         tokenizer
             An Outlines tokenizer
         """
-        guide = RegexGuide(regex_string, tokenizer)
+        guide = RegexGuide.from_regex(regex_string, tokenizer)
         super().__init__(tokenizer=tokenizer, guide=guide)
 
 
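This processor backs the public regex generation API exercised in the integration tests below; a hedged end-to-end sketch using only calls that appear in those tests (the prompt string and final call are illustrative):

    from outlines import generate, models, samplers

    model = models.transformers(
        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM", device="cpu"
    )
    generator = generate.regex(model, "[1-9]", sampler=samplers.greedy())
    digit = generator("Pick a digit: ")  # output is constrained to match [1-9]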
10 changes: 5 additions & 5 deletions tests/fsm/test_guide.py
@@ -43,7 +43,7 @@ def convert_token_to_string(self, token):
     regex_str = "[1-9]"
 
     with pytest.raises(ValueError, match="The vocabulary"):
-        RegexGuide(regex_str, MockTokenizer())
+        RegexGuide.from_regex(regex_str, MockTokenizer())
 
 
 def test_regex():
@@ -57,7 +57,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {0: {1: 1}}
 
@@ -98,7 +98,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 4: 2},
@@ -145,7 +145,7 @@ def convert_token_to_string(self, token):
 
     regex_str = " [😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 10: 2},
@@ -180,7 +180,7 @@ def convert_token_to_string(self, token):
 
     regex_str = r"`\n(\.\n)?`\n"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     state = fsm.get_next_state(state=4, token_id=103)
     assert state == 5
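For orientation on the assertions above: states_to_token_maps maps each FSM state to a dict of allowed token id -> next state, and get_next_state performs a single transition. A short sketch against the first test's expected map (token and state ids are only meaningful for that test's MockTokenizer stub):

    guide = RegexGuide.from_regex("[1-9]", tokenizer)

    allowed = guide.states_to_token_maps[0]       # {1: 1}: token id 1 -> state 1
    next_state = guide.get_next_state(state=0, token_id=1)
    assert next_state == 1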
4 changes: 2 additions & 2 deletions tests/generate/test_integration_llamacpp.py
@@ -278,7 +278,7 @@ def test_RegexGuide_caching(model, temp_cache_dir):
     import llama_cpp
 
     import outlines.caching
-    from outlines.fsm.guide import create_states_mapping
+    from outlines.fsm.guide import cached_create_states_mapping
 
     assert outlines.caching._caching_enabled
 
@@ -291,7 +291,7 @@ def test_RegexGuide_caching(model, temp_cache_dir):
     _ = cache.stats(enable=True)
     assert cache.statistics
 
-    assert create_states_mapping.__memory__ is cache
+    assert cached_create_states_mapping.__memory__ is cache
 
     generator = generate.regex(model, regex, sampler=samplers.greedy())
     assert cache.stats() == (0, 1)
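The __memory__ assertion works because outlines' cache() decorator attaches its backing store to the function it wraps. A toy sketch of that mechanism (this dict-backed version is hypothetical; the real cache is disk-backed):

    def cache():
        store = {}

        def decorator(fn):
            def wrapper(*args):
                if args not in store:       # miss: compute and remember
                    store[args] = fn(*args)
                return store[args]          # hit: reuse the stored result

            wrapper.__memory__ = store      # expose the store, as the test asserts
            return wrapper

        return decorator

The cache.stats() == (0, 1) assertion then reads as zero hits and one miss: the first guide compilation populates the cache rather than reading from it.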
4 changes: 2 additions & 2 deletions tests/generate/test_integration_transformers.py
@@ -494,7 +494,7 @@ def test_transformers_use_existing_model_and_tokenizer():
 
 def test_RegexGuide_caching(temp_cache_dir):
     import outlines.caching
-    from outlines.fsm.guide import create_states_mapping
+    from outlines.fsm.guide import cached_create_states_mapping
 
     assert outlines.caching._caching_enabled
 
@@ -507,7 +507,7 @@ def test_RegexGuide_caching(temp_cache_dir):
     _ = cache.stats(enable=True)
     assert cache.statistics
 
-    assert create_states_mapping.__memory__ is cache
+    assert cached_create_states_mapping.__memory__ is cache
 
     model = models.transformers(
         "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM", device="cpu"
