From 43799457cd33eb1cc9ca232875fec994071ce7c5 Mon Sep 17 00:00:00 2001 From: Krukov D Date: Sat, 18 May 2024 23:42:43 +0300 Subject: [PATCH] fix: check setup for disable not configured cache, feat: get_or_set, fix: tags with func on invalidate --- .pre-commit-config.yaml | 6 ----- Readme.md | 15 ++++++----- cashews/_typing.py | 7 ++--- cashews/commands.py | 1 + cashews/contrib/fastapi.py | 4 ++- cashews/formatter.py | 7 +++-- cashews/helpers.py | 12 +++------ cashews/key.py | 2 +- cashews/key_context.py | 11 +++++--- cashews/validation.py | 4 ++- cashews/wrapper/auto_init.py | 6 ++--- cashews/wrapper/commands.py | 21 ++++++++++++++- cashews/wrapper/disable_control.py | 9 ++++--- examples/bug.py | 42 ++++++++++++++++++++++++++++++ examples/keys.py | 2 +- examples/simple.py | 2 ++ pyproject.toml | 1 + tests/test_backend_commands.py | 24 +++++++++++++++++ tests/test_disable_control.py | 5 ++++ tests/test_invalidate.py | 36 +++++++++++++++++++++++++ tests/test_transaction.py | 33 +++++++++++++++++++++++ 21 files changed, 209 insertions(+), 41 deletions(-) create mode 100644 examples/bug.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 10f94e9..e10dbf5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,12 +13,6 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - id: debug-statements -# -# - repo: https://github.com/pre-commit/mirrors-prettier -# rev: v4.0.0-alpha.8 -# hooks: -# - id: prettier -# stages: [commit] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.2 diff --git a/Readme.md b/Readme.md index 915c7c7..ef07548 100644 --- a/Readme.md +++ b/Readme.md @@ -196,16 +196,19 @@ from cashews import cache cache.setup("mem://") # configure as in-memory cache -await cache.set(key="key", value=90, expire=60, exist=None) # -> bool +await cache.set(key="key", value=90, expire="2h", exist=None) # -> bool await cache.set_raw(key="key", value="str") # -> bool +await cache.set_many({"key1": value, "key2": value}) # -> None await cache.get("key", default=None) # -> Any -await cache.get_raw("key") -await cache.get_many("key1", "key2", default=None) +await cache.get_or_set("key", default=awaitable_or_callable, expire="1h") # -> Any +await cache.get_raw("key") # -> Any +await cache.get_many("key1", "key2", default=None) # -> tuple[Any] async for key, value in cache.get_match("pattern:*", batch_size=100): ... await cache.incr("key") # -> int +await cache.exists("key") # -> bool await cache.delete("key") await cache.delete_many("key1", "key2") @@ -928,8 +931,8 @@ E.g. A simple middleware to use it in a web app: async def add_from_cache_headers(request: Request, call_next): with cache.detect as detector: response = await call_next(request) - if detector.keys: - key = list(detector.keys.keys())[0] + if detector.calls: + key = list(detector.calls.keys())[0] response.headers["X-From-Cache"] = key expire = await cache.get_expire(key) response.headers["X-From-Cache-Expire-In-Seconds"] = str(expire) @@ -1004,7 +1007,7 @@ Here we want to have some way to protect our code from race conditions and do op Cashews support transaction operations: -> :warning: \*\*Warning: transaction operations are `set`, `set_many`, `delete`, `delete_many`, `delete_match` and `incr` + > :warning: \*\*Warning: transaction operations are `set`, `set_many`, `delete`, `delete_many`, `delete_match` and `incr` ```python from cashews import cache diff --git a/cashews/_typing.py b/cashews/_typing.py index 0152d2c..70aaee9 100644 --- a/cashews/_typing.py +++ b/cashews/_typing.py @@ -30,8 +30,9 @@ def __call__( CacheCondition = Union[CallableCacheCondition, str, None] -AsyncCallableResult_T = TypeVar("AsyncCallableResult_T") -AsyncCallable_T = Callable[..., Awaitable[AsyncCallableResult_T]] +Result_T = TypeVar("Result_T") +AsyncCallable_T = Callable[..., Awaitable[Result_T]] +Callable_T = Callable[..., Result_T] DecoratedFunc = TypeVar("DecoratedFunc", bound=AsyncCallable_T) @@ -44,7 +45,7 @@ def __call__( backend: Backend, *args, **kwargs, - ) -> Awaitable[AsyncCallableResult_T | None]: # pragma: no cover + ) -> Awaitable[Result_T | None]: # pragma: no cover ... diff --git a/cashews/commands.py b/cashews/commands.py index 730a7f7..43b8d31 100644 --- a/cashews/commands.py +++ b/cashews/commands.py @@ -19,6 +19,7 @@ class Command(Enum): EXPIRE = "expire" GET_EXPIRE = "get_expire" CLEAR = "clear" + SET_LOCK = "set_lock" UNLOCK = "unlock" IS_LOCKED = "is_locked" diff --git a/cashews/contrib/fastapi.py b/cashews/contrib/fastapi.py index a6216f6..9a1fcc2 100644 --- a/cashews/contrib/fastapi.py +++ b/cashews/contrib/fastapi.py @@ -55,10 +55,12 @@ def __init__( cache_instance: Cache = cache, methods: Sequence[str] = ("get",), private=True, + prefix_to_disable: str = "", ): self._private = private self._cache = cache_instance self._methods = methods + self._prefix_to_disable = prefix_to_disable super().__init__(app) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: @@ -68,7 +70,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - return await call_next(request) to_disable = _to_disable(cache_control_value) if to_disable: - context = self._cache.disabling(*to_disable) + context = self._cache.disabling(*to_disable, prefix=self._prefix_to_disable) with context, max_age(cache_control_value), self._cache.detect as detector: response = await call_next(request) calls = detector.calls_list diff --git a/cashews/formatter.py b/cashews/formatter.py index 0df6825..2db2add 100644 --- a/cashews/formatter.py +++ b/cashews/formatter.py @@ -158,8 +158,11 @@ def _upper(value: TemplateValue) -> TemplateValue: def default_format(template: KeyTemplate, **values) -> KeyOrTemplate: - _template_context = key_context.get() - _template_context.update(values) + _template_context, rewrite = key_context.get() + if rewrite: + _template_context = {**values, **_template_context} + else: + _template_context = {**_template_context, **values} return default_formatter.format(template, **_template_context) diff --git a/cashews/helpers.py b/cashews/helpers.py index d6f1ae8..c8e1428 100644 --- a/cashews/helpers.py +++ b/cashews/helpers.py @@ -1,6 +1,6 @@ from typing import Optional -from ._typing import AsyncCallable_T, AsyncCallableResult_T, Middleware +from ._typing import AsyncCallable_T, Middleware, Result_T from .backends.interface import Backend from .commands import PATTERN_CMDS, Command from .key import get_call_values @@ -8,9 +8,7 @@ def add_prefix(prefix: str) -> Middleware: - async def _middleware( - call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs - ) -> AsyncCallableResult_T: + async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T: if cmd == Command.GET_MANY: return await call(*[prefix + key for key in args]) call_values = get_call_values(call, args, kwargs) @@ -29,9 +27,7 @@ async def _middleware( def all_keys_lower() -> Middleware: - async def _middleware( - call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs - ) -> AsyncCallableResult_T: + async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T: if cmd == Command.GET_MANY: return await call(*[key.lower() for key in args]) call_values = get_call_values(call, args, kwargs) @@ -54,7 +50,7 @@ async def _middleware( def memory_limit(min_bytes=0, max_bytes=None) -> Middleware: async def _middleware( call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs - ) -> Optional[AsyncCallableResult_T]: + ) -> Optional[Result_T]: if cmd != Command.SET: return await call(*args, **kwargs) call_values = get_call_values(call, args, kwargs) diff --git a/cashews/key.py b/cashews/key.py index 69a7d12..29f63a5 100644 --- a/cashews/key.py +++ b/cashews/key.py @@ -119,7 +119,7 @@ def _default(name): return "*" check = _ReplaceFormatter(default=_default) - check.format(key, **{**get_key_context(), **func_params}) + check.format(key, **{**get_key_context()[0], **func_params}) if errors: raise WrongKeyError(f"Wrong parameter placeholder '{errors}' in the key ") diff --git a/cashews/key_context.py b/cashews/key_context.py index c6019b2..a2b5f0f 100644 --- a/cashews/key_context.py +++ b/cashews/key_context.py @@ -4,12 +4,14 @@ from contextvars import ContextVar from typing import Any, Iterator -_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context", default={}) +_REWRITE = "__rewrite" +_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context", default={_REWRITE: False}) @contextmanager -def context(**values) -> Iterator[None]: +def context(rewrite=False, **values) -> Iterator[None]: new_context = {**_template_context.get(), **values} + new_context[_REWRITE] = rewrite token = _template_context.set(new_context) try: yield @@ -17,8 +19,9 @@ def context(**values) -> Iterator[None]: _template_context.reset(token) -def get(): - return {**_template_context.get()} +def get() -> tuple[dict[str, Any], bool]: + _context = {**_template_context.get()} + return _context, _context.pop(_REWRITE) def register(*names: str) -> None: diff --git a/cashews/validation.py b/cashews/validation.py index 0c9ee16..8aec9d2 100644 --- a/cashews/validation.py +++ b/cashews/validation.py @@ -8,6 +8,7 @@ from .commands import RETRIEVE_CMDS, Command from .formatter import default_format from .key import get_call_values +from .key_context import context as template_context def invalidate( @@ -29,7 +30,8 @@ async def _wrap(*args, **kwargs): if dest in _args: _args[source] = _args.pop(dest) key = default_format(key_template, **_args) - await backend.delete_match(key) + with template_context(**_args, rewrite=True): + await backend.delete_match(key) return result return _wrap diff --git a/cashews/wrapper/auto_init.py b/cashews/wrapper/auto_init.py index a2cf4a3..a88b21d 100644 --- a/cashews/wrapper/auto_init.py +++ b/cashews/wrapper/auto_init.py @@ -1,6 +1,6 @@ import asyncio -from cashews._typing import AsyncCallable_T, AsyncCallableResult_T, Middleware +from cashews._typing import AsyncCallable_T, Middleware, Result_T from cashews.backends.interface import Backend from cashews.commands import Command @@ -8,9 +8,7 @@ def create_auto_init() -> Middleware: lock = asyncio.Lock() - async def _auto_init( - call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs - ) -> AsyncCallableResult_T: + async def _auto_init(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T: if backend.is_init: return await call(*args, **kwargs) async with lock: diff --git a/cashews/wrapper/commands.py b/cashews/wrapper/commands.py index 2ef6888..4ad076b 100644 --- a/cashews/wrapper/commands.py +++ b/cashews/wrapper/commands.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from functools import partial from typing import TYPE_CHECKING, AsyncIterator, Iterable, Mapping, overload @@ -10,7 +11,9 @@ from .wrapper import Wrapper if TYPE_CHECKING: # pragma: no cover - from cashews._typing import TTL, Default, Key, Value + from cashews._typing import TTL, AsyncCallable_T, Callable_T, Default, Key, Result_T, Value + +_empty = object() class CommandWrapper(Wrapper): @@ -40,6 +43,22 @@ async def get(self, key: Key, default: None = None) -> Value | None: ... async def get(self, key: Key, default: Default | None = None) -> Value | Default | None: return await self._with_middlewares(Command.GET, key)(key=key, default=default) + async def get_or_set( + self, key: Key, default: Default | AsyncCallable_T | Callable_T, expire: TTL = None + ) -> Value | Default | Result_T: + value = await self.get(key, default=_empty) + if value is not _empty: + return value + if callable(default): + if inspect.iscoroutinefunction(default): + _default = await default() + else: + _default = default() + else: + _default = default + await self.set(key, _default, expire=expire) + return default + async def get_raw(self, key: Key) -> Value: return await self._with_middlewares(Command.GET_RAW, key)(key=key) diff --git a/cashews/wrapper/disable_control.py b/cashews/wrapper/disable_control.py index 559f125..a5181f0 100644 --- a/cashews/wrapper/disable_control.py +++ b/cashews/wrapper/disable_control.py @@ -1,9 +1,10 @@ from __future__ import annotations -from contextlib import contextmanager +from contextlib import contextmanager, suppress from typing import TYPE_CHECKING, Iterator from cashews.commands import Command +from cashews.exceptions import NotConfiguredError from .wrapper import Wrapper @@ -26,7 +27,8 @@ def __init__(self, name: str = ""): self.add_middleware(_is_disable_middleware) def disable(self, *cmds: Command, prefix: str = "") -> None: - return self._get_backend(prefix).disable(*cmds) + with suppress(NotConfiguredError): + return self._get_backend(prefix).disable(*cmds) def enable(self, *cmds: Command, prefix: str = "") -> None: return self._get_backend(prefix).enable(*cmds) @@ -37,7 +39,8 @@ def disabling(self, *cmds: Command, prefix: str = "") -> Iterator[None]: try: yield finally: - self.enable(*cmds, prefix=prefix) + with suppress(NotConfiguredError): + self.enable(*cmds, prefix=prefix) def is_disable(self, *cmds: Command, prefix: str = "") -> bool: return self._get_backend(prefix).is_disable(*cmds) diff --git a/examples/bug.py b/examples/bug.py new file mode 100644 index 0000000..4b48e41 --- /dev/null +++ b/examples/bug.py @@ -0,0 +1,42 @@ +import asyncio +from collections.abc import Mapping +from typing import Any + +from cashews import cache, default_formatter + +cache.setup("mem://?size=1000000&check_interval=5") + + +@default_formatter.register("get_item", preformat=False) +def _getitem_func(mapping: Mapping[str, Any], key: str) -> str: + try: + return str(mapping[key]) + except Exception as e: + # when key/tag matching, this may be called with the rendered value + raise RuntimeError(f"{mapping=}, {key=}") from e + + +@cache( + ttl="1h", + key="prefix:keys:{mapping:get_item(bar)}", + tags=["prefix:tags:{mapping:get_item(bar)}"], +) +async def foo(mapping: str) -> None: + print("Foo", mapping) + + +@cache.invalidate("prefix:keys:{mapping:get_item(bar)}") +async def bar(mapping: str) -> None: + print("Bar", mapping) + + +async def main() -> None: + await foo({"bar": "baz"}) + await bar({"bar": "baz"}) + + +if __name__ == "__main__": + asyncio.run(main()) + +# prints Foo {'bar': 'baz'} +# prints Bar {'bar': 'baz'} diff --git a/examples/keys.py b/examples/keys.py index c752cf3..4674a91 100644 --- a/examples/keys.py +++ b/examples/keys.py @@ -56,7 +56,7 @@ async def _call(function, *args, **kwargs): await function(*args, **kwargs) with cache.detect as detector: await function(*args, **kwargs) - key = list(detector.keys.keys())[-1] + key = list(detector.calls.keys())[-1] print( f""" diff --git a/examples/simple.py b/examples/simple.py index 864a156..a1faf49 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -18,6 +18,8 @@ async def basic(): await cache.set("key", 1) assert await cache.get("key") == 1 await cache.set("key1", value={"any": True}, expire="1m") + print(await cache.get_or_set("key200", default=lambda: "test")) + print(await cache.get_or_set("key10", default="test")) await cache.set_many({"key2": "test", "key3": Decimal("10.1")}, expire="1m") print("Get: ", await cache.get("key1")) # -> Any diff --git a/pyproject.toml b/pyproject.toml index 571f2b1..5925839 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ line-length = 119 [tool.ruff.lint] select = ["E", "F", "B", "I", "SIM", "UP", "C4"] +ignore = ["SIM108"] [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = [ diff --git a/tests/test_backend_commands.py b/tests/test_backend_commands.py index 35ae6e9..92cbc46 100644 --- a/tests/test_backend_commands.py +++ b/tests/test_backend_commands.py @@ -15,6 +15,30 @@ async def test_set_get(cache: Cache): assert await cache.get("key") == VALUE +async def test_get_or_set(cache: Cache): + await cache.get_or_set("key", default=VALUE) + assert await cache.get("key") == VALUE + + +async def test_get_or_set_callable(cache: Cache): + await cache.get_or_set("key", default=lambda: VALUE) + assert await cache.get("key") == VALUE + + +async def test_get_or_set_awaitable(cache: Cache): + async def _default(): + return VALUE + + await cache.get_or_set("key", default=_default) + assert await cache.get("key") == VALUE + + +async def test_get_or_set_no_set(cache: Cache): + await cache.set("key", VALUE) + await cache.get_or_set("key", None) + assert await cache.get("key") == VALUE + + async def test_set_get_bytes(cache: Cache): await cache.set("key", b"10") assert await cache.get("key") == b"10" diff --git a/tests/test_disable_control.py b/tests/test_disable_control.py index 8cb6240..7f4f15f 100644 --- a/tests/test_disable_control.py +++ b/tests/test_disable_control.py @@ -65,6 +65,11 @@ async def test_disable_context_manage_get(cache): assert await cache.get("test") is None +def test_disable_context_manage_no_init(): + cache = Cache() + cache.disable(Command.GET) + + async def test_disable_context_manage_decor(cache): @cache(ttl="1m") async def func(): diff --git a/tests/test_invalidate.py b/tests/test_invalidate.py index 70657d3..6bbc06c 100644 --- a/tests/test_invalidate.py +++ b/tests/test_invalidate.py @@ -23,6 +23,42 @@ async def func2(arg, default=True): assert first_call != await func("test") +async def test_invalidate_decor_tag(cache: Cache): + @cache(ttl=1, key="key:{arg}", tags=("key:{arg}", "key")) + async def func(arg): + return random.random() + + @cache.invalidate("key:{arg}") + async def func2(arg, default=True): + return random.random() + + first_call = await func("test") + await asyncio.sleep(0.01) + + assert first_call == await func("test") + await func2("test") + await asyncio.sleep(0.01) + assert first_call != await func("test") + + +async def test_invalidate_decor_tag_func(cache: Cache): + @cache(ttl=1, key="key:{arg:hash}", tags=("key:{arg}", "key")) + async def func(arg): + return random.random() + + @cache.invalidate("key:{arg:hash}") + async def func2(arg, default=True): + return random.random() + + first_call = await func("test") + await asyncio.sleep(0.01) + + assert first_call == await func("test") + await func2("test") + await asyncio.sleep(0.01) + assert first_call != await func("test") + + async def test_invalidate_further_decorator(cache): @cache(ttl=100) async def func(): diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 4e3cd8a..6334ab6 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -79,6 +79,39 @@ async def test_transaction_set_exception(cache: Cache, tx_mode): assert await cache.get("key2") == "value2" +async def test_transaction_get_or_set(cache: Cache, tx_mode): + await cache.set("key1", "value1", expire=1) + + async with cache.transaction(tx_mode): + await cache.get_or_set("key", "value") + await cache.get_or_set("key1", "value2") + + assert await cache.get("key") == "value" + assert await cache.get("key1") == "value1" + + assert await cache.get("key") == "value" + assert await cache.get("key1") == "value1" + + +async def test_transaction_get_or_set_rollback(cache: Cache, tx_mode): + await cache.set("key1", "value1", expire=1) + + async with cache.transaction(tx_mode) as tx: + await cache.get_or_set("key", "value") + await cache.get_or_set("key1", "value2") + + assert await cache.get("key") == "value" + assert await cache.get("key1") == "value1" + + await tx.rollback() + + assert await cache.get("key") is None + assert await cache.get("key1") == "value1" + + assert await cache.get("key") is None + assert await cache.get("key1") == "value1" + + async def test_transaction_set_delete_get_many(cache: Cache, tx_mode): await cache.set("key1", "value", expire=1) await cache.set("key2", "value2", expire=2)