diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py
index daa77510b..8d1ceeb24 100644
--- a/benchmarks/bench_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -1,15 +1,8 @@
-import outlines
+from outlines.caching import cache_disabled
+from outlines.fsm.guide import RegexGuide
+from outlines.fsm.json_schema import build_regex_from_schema
 
-outlines.disable_cache()
-
-from outlines.fsm.guide import RegexGuide  # noqa: E402
-from outlines.fsm.json_schema import build_regex_from_schema  # noqa: E402
-
-from .common import (  # noqa: E402
-    clear_outlines_cache,
-    ensure_numba_compiled,
-    setup_tokenizer,
-)
+from .common import ensure_numba_compiled, setup_tokenizer  # noqa: E402
 
 simple_schema = """{
        "$defs": {
@@ -74,14 +67,15 @@ class JsonSchemaBenchmark:
     params = schemas.keys()
 
     def setup(self, schema_name):
-        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         self.schema = schemas[schema_name]
         ensure_numba_compiled(self.tokenizer)
 
+    @cache_disabled()
     def time_json_schema_to_regex(self, schema_name):
         build_regex_from_schema(self.schema)
 
+    @cache_disabled()
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
         RegexGuide(regex, self.tokenizer)
diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py
index c0e9d87c4..2713707e5 100644
--- a/benchmarks/bench_numba_compile.py
+++ b/benchmarks/bench_numba_compile.py
@@ -3,18 +3,14 @@
 import interegular
 import numba
 
-import outlines
+from outlines.caching import cache_disabled
+from outlines.fsm import regex
 
-from .common import clear_outlines_cache, setup_tokenizer
-
-outlines.disable_cache()
+from .common import setup_tokenizer
 
 
 class NumbaCompileBenchmark:
     def setup(self):
-        clear_outlines_cache()
-        from outlines.fsm import regex
-
         self.tokenizer = setup_tokenizer()
         self.regex = regex
         original_njit = numba.njit
@@ -33,5 +29,6 @@ def mock_njit(*args, **kwargs):
     def teardown(self):
         numba.njit = self.original_njit
 
+    @cache_disabled()
     def time_compile_numba(self):
         self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer)
diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index efaea9e1f..099f94df2 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -1,10 +1,7 @@
-import outlines
+from outlines.caching import cache_disabled
+from outlines.fsm.guide import RegexGuide
 
-from .common import clear_outlines_cache, ensure_numba_compiled, setup_tokenizer
-
-outlines.disable_cache()
-
-from outlines.fsm.guide import RegexGuide  # noqa: E402
+from .common import ensure_numba_compiled, setup_tokenizer
 
 regex_samples = {
     "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
@@ -23,11 +20,11 @@ class RegexGuideBenchmark:
     params = regex_samples.keys()
 
     def setup(self, pattern_name):
-        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
+    @cache_disabled()
     def time_regex_to_guide(self, pattern_name):
         RegexGuide(self.pattern, self.tokenizer)
 
@@ -36,10 +33,10 @@ class MemoryRegexGuideBenchmark:
     params = ["simple_phone", "complex_span_constrained_relation_extraction"]
 
     def setup(self, pattern_name):
-        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
+    @cache_disabled()
     def peakmem_regex_to_guide(self, pattern_name):
         RegexGuide(self.pattern, self.tokenizer)
diff --git a/benchmarks/common.py b/benchmarks/common.py
index e0fe36f14..7d999ea9b 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -1,14 +1,9 @@
 from transformers import AutoTokenizer
 
-import outlines.caching
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
 
 
-def clear_outlines_cache():
-    outlines.caching.clear_cache()
-
-
 def setup_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     return TransformerTokenizer(tokenizer)
diff --git a/outlines/caching.py b/outlines/caching.py
index 52d66af74..95392c7e8 100644
--- a/outlines/caching.py
+++ b/outlines/caching.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import functools
 import os
 from typing import Callable, Optional
@@ -164,3 +165,15 @@ def clear_cache():
     """Erase the cache completely."""
     memory = get_cache()
     memory.clear()
+
+
+@contextlib.contextmanager
+def cache_disabled():
+    # outlines.caching._caching_enabled
+    global _caching_enabled
+    original_state = _caching_enabled
+    _caching_enabled = False
+    try:
+        yield
+    finally:
+        _caching_enabled = original_state
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 5a2de778e..eb4ec406e 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -1,5 +1,6 @@
 import os
 import tempfile
+import unittest
 
 import diskcache
 import pytest
@@ -157,3 +158,33 @@ def foo():
 
     # assert with version upgrade, old cache is invalidated and new cache is used
     a, b = foo()
+
+
+def test_cache_disabled_decorator(test_cache):
+    """Ensure cache can be disabled in a local scope"""
+
+    from outlines.caching import cache_disabled
+
+    mock = unittest.mock.MagicMock()
+
+    @test_cache
+    def fn():
+        mock()
+        return 1
+
+    # first call isn't cached
+    fn()
+    assert mock.call_count == 1
+
+    # second call doesn't run fn, uses cache
+    fn()
+    assert mock.call_count == 1
+
+    # cache_disabled decorator disables cache within scope
+    with cache_disabled():
+        fn()
+        assert mock.call_count == 2  # called once in cache_disabled scope
+
+    # scope has exited, cache is enabled again
+    fn()
+    assert mock.call_count == 2
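
Usage note: cache_disabled() is built with contextlib.contextmanager, and the object it
returns also works as a decorator (contextlib-produced context managers inherit from
contextlib.ContextDecorator). That is why the benchmarks above apply it as @cache_disabled()
while the test uses a with-block. Below is a minimal sketch of both call styles, assuming
this patch is applied; the schema string is only an illustrative example, not taken from
the benchmarks.

    from outlines.caching import cache_disabled
    from outlines.fsm.json_schema import build_regex_from_schema

    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'

    # Context-manager style: caching is suspended only inside the block,
    # then restored to its previous state on exit.
    with cache_disabled():
        build_regex_from_schema(schema)

    # Decorator style: every call to the wrapped function runs with
    # caching temporarily disabled.
    @cache_disabled()
    def build_uncached():
        return build_regex_from_schema(schema)

    build_uncached()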