disable the outlines cache, localized to the benchmarks scope
lapp0 committed Jun 4, 2024
1 parent b7ab26c commit 5a6e154
Showing 6 changed files with 59 additions and 32 deletions.
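In short: the benchmarks previously switched caching off for the whole process by calling outlines.disable_cache() at import time; this commit removes that global toggle and instead wraps each benchmark body in a new, scoped cache_disabled() context manager. A minimal before/after sketch, using only names that appear in the diffs below:

# Before: module-level call, the cache stays off for everything in the process.
import outlines

outlines.disable_cache()

# After: caching is suspended only while the decorated benchmark runs.
from outlines.caching import cache_disabled
from outlines.fsm.json_schema import build_regex_from_schema


class JsonSchemaBenchmark:
    @cache_disabled()
    def time_json_schema_to_regex(self, schema_name):
        build_regex_from_schema(self.schema)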
benchmarks/bench_json_schema.py (18 changes: 6 additions & 12 deletions)
@@ -1,15 +1,8 @@
-import outlines
+from outlines.caching import cache_disabled
+from outlines.fsm.guide import RegexGuide
+from outlines.fsm.json_schema import build_regex_from_schema
 
-outlines.disable_cache()
-
-from outlines.fsm.guide import RegexGuide  # noqa: E402
-from outlines.fsm.json_schema import build_regex_from_schema  # noqa: E402
-
-from .common import (  # noqa: E402
-    clear_outlines_cache,
-    ensure_numba_compiled,
-    setup_tokenizer,
-)
+from .common import ensure_numba_compiled, setup_tokenizer  # noqa: E402
 
 simple_schema = """{
     "$defs": {
@@ -74,14 +67,15 @@ class JsonSchemaBenchmark:
     params = schemas.keys()
 
     def setup(self, schema_name):
-        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         self.schema = schemas[schema_name]
         ensure_numba_compiled(self.tokenizer)
 
+    @cache_disabled()
     def time_json_schema_to_regex(self, schema_name):
         build_regex_from_schema(self.schema)
 
+    @cache_disabled()
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
         RegexGuide(regex, self.tokenizer)
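These classes follow airspeed velocity (asv) naming conventions (params, setup, and time_* methods); that harness is an assumption from the naming, not something the commit states. A rough sketch of the call shape such a harness produces, to show where caching is now on and off (the driver loop is hypothetical; asv handles timing and repetition itself):

bench = JsonSchemaBenchmark()
for schema_name in bench.params:
    bench.setup(schema_name)                      # caching still enabled here
    bench.time_json_schema_to_regex(schema_name)  # body runs with the cache disabled
    bench.time_json_schema_to_fsm(schema_name)    # likewise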
benchmarks/bench_numba_compile.py (11 changes: 4 additions & 7 deletions)
@@ -3,18 +3,14 @@
 import interegular
 import numba
 
-import outlines
+from outlines.caching import cache_disabled
+from outlines.fsm import regex
 
-from .common import clear_outlines_cache, setup_tokenizer
-
-outlines.disable_cache()
+from .common import setup_tokenizer
 
 
 class NumbaCompileBenchmark:
     def setup(self):
-        clear_outlines_cache()
-        from outlines.fsm import regex
-
         self.tokenizer = setup_tokenizer()
         self.regex = regex
         original_njit = numba.njit
@@ -33,5 +29,6 @@ def mock_njit(*args, **kwargs):
     def teardown(self):
         numba.njit = self.original_njit
 
+    @cache_disabled()
     def time_compile_numba(self):
         self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer)
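The hunk header above references a mock_njit helper whose body is elided from this diff; the visible lines only show numba.njit being saved in setup and restored in teardown. A generic sketch of that save/patch/restore pattern, with a pass-through body that is purely illustrative (the benchmark's real mock_njit is not shown here):

import numba


class PatchedNjitSketch:
    def setup(self):
        # keep a handle to the real decorator so teardown can restore it
        self.original_njit = numba.njit

        def mock_njit(*args, **kwargs):
            return self.original_njit(*args, **kwargs)  # illustrative pass-through

        numba.njit = mock_njit

    def teardown(self):
        # always restore the global so later code sees the real njit
        numba.njit = self.original_njit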
benchmarks/bench_regex_guide.py (13 changes: 5 additions & 8 deletions)
@@ -1,10 +1,7 @@
-import outlines
+from outlines.caching import cache_disabled
+from outlines.fsm.guide import RegexGuide
 
-from .common import clear_outlines_cache, ensure_numba_compiled, setup_tokenizer
-
-outlines.disable_cache()
-
-from outlines.fsm.guide import RegexGuide  # noqa: E402
+from .common import ensure_numba_compiled, setup_tokenizer
 
 regex_samples = {
     "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
@@ -23,11 +20,11 @@ class RegexGuideBenchmark:
     params = regex_samples.keys()
 
     def setup(self, pattern_name):
-        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
+    @cache_disabled()
     def time_regex_to_guide(self, pattern_name):
         RegexGuide(self.pattern, self.tokenizer)
 
@@ -36,10 +33,10 @@ class MemoryRegexGuideBenchmark:
     params = ["simple_phone", "complex_span_constrained_relation_extraction"]
 
     def setup(self, pattern_name):
-        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
 
+    @cache_disabled()
     def peakmem_regex_to_guide(self, pattern_name):
         RegexGuide(self.pattern, self.tokenizer)
benchmarks/common.py (5 changes: 0 additions & 5 deletions)
@@ -1,14 +1,9 @@
 from transformers import AutoTokenizer
 
-import outlines.caching
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
 
 
-def clear_outlines_cache():
-    outlines.caching.clear_cache()
-
-
 def setup_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     return TransformerTokenizer(tokenizer)
outlines/caching.py (13 changes: 13 additions & 0 deletions)
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import functools
 import os
 from typing import Callable, Optional
@@ -164,3 +165,15 @@ def clear_cache():
     """Erase the cache completely."""
     memory = get_cache()
     memory.clear()
+
+
+@contextlib.contextmanager
+def cache_disabled():
+    # temporarily flip the module-level outlines.caching._caching_enabled flag
+    global _caching_enabled
+    original_state = _caching_enabled
+    _caching_enabled = False
+    try:
+        yield
+    finally:
+        _caching_enabled = original_state
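A minimal usage sketch of the new helper: because contextlib.contextmanager produces a ContextDecorator, cache_disabled() works both as a with-block (as in tests/test_cache.py below) and as a decorator (as on the benchmark methods above). some_cached_function is a stand-in for any function memoized through outlines' caching:

from outlines.caching import cache_disabled


def some_cached_function():
    return 1  # stand-in for a function whose results outlines would cache


# as a context manager
with cache_disabled():
    some_cached_function()  # runs while _caching_enabled is False

# as a decorator
@cache_disabled()
def run_uncached():
    return some_cached_function()  # the whole call executes with caching off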
tests/test_cache.py (31 changes: 31 additions & 0 deletions)
@@ -1,5 +1,6 @@
 import os
 import tempfile
+import unittest
 
 import diskcache
 import pytest
@@ -157,3 +158,33 @@ def foo():
 
     # assert with version upgrade, old cache is invalidated and new cache is used
     a, b = foo()
+
+
+def test_cache_disabled_decorator(test_cache):
+    """Ensure the cache can be disabled in a local scope."""
+
+    from outlines.caching import cache_disabled
+
+    mock = unittest.mock.MagicMock()
+
+    @test_cache
+    def fn():
+        mock()
+        return 1
+
+    # the first call isn't cached
+    fn()
+    assert mock.call_count == 1
+
+    # the second call doesn't run fn, it uses the cache
+    fn()
+    assert mock.call_count == 1
+
+    # the cache_disabled context manager disables the cache within its scope
+    with cache_disabled():
+        fn()
+    assert mock.call_count == 2  # called once inside the cache_disabled scope
+
+    # the scope has exited, so the cache is enabled again
+    fn()
+    assert mock.call_count == 2
