From 0b457eccf7921445b98cb8703e824990ede7afec Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 10 Oct 2024 19:45:47 -0400 Subject: [PATCH] update RegexGuide to conform with outlines-core --- benchmarks/bench_json_schema.py | 2 +- benchmarks/bench_regex_guide.py | 4 +-- outlines/fsm/guide.py | 27 +++++++++++-------- outlines/processors/structured.py | 2 +- tests/fsm/test_guide.py | 10 +++---- tests/generate/test_integration_llamacpp.py | 4 +-- .../generate/test_integration_transformers.py | 4 +-- 7 files changed, 29 insertions(+), 24 deletions(-) diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py index 8990b015c..62d9b3c1d 100644 --- a/benchmarks/bench_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -77,4 +77,4 @@ def time_json_schema_to_regex(self, schema_name): @cache_disabled() def time_json_schema_to_fsm(self, schema_name): regex = build_regex_from_schema(self.schema) - RegexGuide(regex, self.tokenizer) + RegexGuide.from_regex(regex, self.tokenizer) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index 7aaef6bac..fa23a724f 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -25,7 +25,7 @@ def setup(self, pattern_name): @cache_disabled() def time_regex_to_guide(self, pattern_name): - RegexGuide(self.pattern, self.tokenizer) + RegexGuide.from_regex(self.pattern, self.tokenizer) class MemoryRegexGuideBenchmark: @@ -37,4 +37,4 @@ def setup(self, pattern_name): @cache_disabled() def peakmem_regex_to_guide(self, pattern_name): - RegexGuide(self.pattern, self.tokenizer) + RegexGuide.from_regex(self.pattern, self.tokenizer) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 697597234..b66490f81 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -74,8 +74,8 @@ def copy(self): @cache() -def create_states_mapping(regex_string, tokenizer): - return uncached_create_states_mapping(regex_string, tokenizer) +def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs): + return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs) class RegexGuide(CoreRegexGuide): @@ -84,15 +84,20 @@ class RegexGuide(CoreRegexGuide): CoreRegexGuide with outlines cache """ - def __init__(self, regex_string: str, tokenizer: "Tokenizer"): - ( - self.states_to_token_maps, - self.empty_token_ids, - fsm_finals, - ) = create_states_mapping(regex_string, tokenizer) - self.eos_token_id = tokenizer.eos_token_id - self.final_states = fsm_finals | {-1} - self._cache_state_to_token_tensor() + @classmethod + def from_regex( + cls, + regex_string: str, + tokenizer, + _create_states_mapping=cached_create_states_mapping, + **kwargs, + ): + return super().from_regex( + regex_string, + tokenizer, + _create_states_mapping=_create_states_mapping, + **kwargs, + ) CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"]) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index e3b9e60d3..d2bc15f77 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -149,7 +149,7 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"): tokenizer An Outlines tokenizer """ - guide = RegexGuide(regex_string, tokenizer) + guide = RegexGuide.from_regex(regex_string, tokenizer) super().__init__(tokenizer=tokenizer, guide=guide) diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 67b4e0dd8..510faf4b0 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -43,7 +43,7 @@ def convert_token_to_string(self, token): regex_str = "[1-9]" with pytest.raises(ValueError, match="The vocabulary"): - RegexGuide(regex_str, MockTokenizer()) + RegexGuide.from_regex(regex_str, MockTokenizer()) def test_regex(): @@ -57,7 +57,7 @@ def convert_token_to_string(self, token): regex_str = "[1-9]" tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) + fsm = RegexGuide.from_regex(regex_str, tokenizer) assert fsm.states_to_token_maps == {0: {1: 1}} @@ -98,7 +98,7 @@ def convert_token_to_string(self, token): regex_str = "[😁-😎]" tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) + fsm = RegexGuide.from_regex(regex_str, tokenizer) assert fsm.states_to_token_maps == { 0: {5: 1, 4: 2}, @@ -145,7 +145,7 @@ def convert_token_to_string(self, token): regex_str = " [😁-😎]" tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) + fsm = RegexGuide.from_regex(regex_str, tokenizer) assert fsm.states_to_token_maps == { 0: {5: 1, 10: 2}, @@ -180,7 +180,7 @@ def convert_token_to_string(self, token): regex_str = r"`\n(\.\n)?`\n" tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) + fsm = RegexGuide.from_regex(regex_str, tokenizer) state = fsm.get_next_state(state=4, token_id=103) assert state == 5 diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 08521c672..8d4596d60 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -278,7 +278,7 @@ def test_RegexGuide_caching(model, temp_cache_dir): import llama_cpp import outlines.caching - from outlines.fsm.guide import create_states_mapping + from outlines.fsm.guide import cached_create_states_mapping assert outlines.caching._caching_enabled @@ -291,7 +291,7 @@ def test_RegexGuide_caching(model, temp_cache_dir): _ = cache.stats(enable=True) assert cache.statistics - assert create_states_mapping.__memory__ is cache + assert cached_create_states_mapping.__memory__ is cache generator = generate.regex(model, regex, sampler=samplers.greedy()) assert cache.stats() == (0, 1) diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index 1d26a9ee4..2462d9fcf 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -494,7 +494,7 @@ def test_transformers_use_existing_model_and_tokenizer(): def test_RegexGuide_caching(temp_cache_dir): import outlines.caching - from outlines.fsm.guide import create_states_mapping + from outlines.fsm.guide import cached_create_states_mapping assert outlines.caching._caching_enabled @@ -507,7 +507,7 @@ def test_RegexGuide_caching(temp_cache_dir): _ = cache.stats(enable=True) assert cache.statistics - assert create_states_mapping.__memory__ is cache + assert cached_create_states_mapping.__memory__ is cache model = models.transformers( "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM", device="cpu"