Skip to content

Commit

Permalink
fix: check setup for disable not configured cache, feat: get_or_set, …
Browse files Browse the repository at this point in the history
…fix: tags with func on invalidate
  • Loading branch information
Krukov committed May 18, 2024
1 parent 65c6ce8 commit 4379945
Show file tree
Hide file tree
Showing 21 changed files with 209 additions and 41 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: debug-statements
#
# - repo: https://github.com/pre-commit/mirrors-prettier
# rev: v4.0.0-alpha.8
# hooks:
# - id: prettier
# stages: [commit]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
Expand Down
15 changes: 9 additions & 6 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,19 @@ from cashews import cache

cache.setup("mem://") # configure as in-memory cache

await cache.set(key="key", value=90, expire=60, exist=None) # -> bool
await cache.set(key="key", value=90, expire="2h", exist=None) # -> bool
await cache.set_raw(key="key", value="str") # -> bool
await cache.set_many({"key1": value, "key2": value}) # -> None

await cache.get("key", default=None) # -> Any
await cache.get_raw("key")
await cache.get_many("key1", "key2", default=None)
await cache.get_or_set("key", default=awaitable_or_callable, expire="1h") # -> Any
await cache.get_raw("key") # -> Any
await cache.get_many("key1", "key2", default=None) # -> tuple[Any]
async for key, value in cache.get_match("pattern:*", batch_size=100):
...

await cache.incr("key") # -> int
await cache.exists("key") # -> bool

await cache.delete("key")
await cache.delete_many("key1", "key2")
Expand Down Expand Up @@ -928,8 +931,8 @@ E.g. A simple middleware to use it in a web app:
async def add_from_cache_headers(request: Request, call_next):
with cache.detect as detector:
response = await call_next(request)
if detector.keys:
key = list(detector.keys.keys())[0]
if detector.calls:
key = list(detector.calls.keys())[0]
response.headers["X-From-Cache"] = key
expire = await cache.get_expire(key)
response.headers["X-From-Cache-Expire-In-Seconds"] = str(expire)
Expand Down Expand Up @@ -1004,7 +1007,7 @@ Here we want to have some way to protect our code from race conditions and do op

Cashews support transaction operations:

> :warning: \*\*Warning: transaction operations are `set`, `set_many`, `delete`, `delete_many`, `delete_match` and `incr`
> :warning: \*\*Warning: transaction operations are `set`, `set_many`, `delete`, `delete_many`, `delete_match` and `incr`
```python
from cashews import cache
Expand Down
7 changes: 4 additions & 3 deletions cashews/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def __call__(

CacheCondition = Union[CallableCacheCondition, str, None]

AsyncCallableResult_T = TypeVar("AsyncCallableResult_T")
AsyncCallable_T = Callable[..., Awaitable[AsyncCallableResult_T]]
Result_T = TypeVar("Result_T")
AsyncCallable_T = Callable[..., Awaitable[Result_T]]
Callable_T = Callable[..., Result_T]

DecoratedFunc = TypeVar("DecoratedFunc", bound=AsyncCallable_T)

Expand All @@ -44,7 +45,7 @@ def __call__(
backend: Backend,
*args,
**kwargs,
) -> Awaitable[AsyncCallableResult_T | None]: # pragma: no cover
) -> Awaitable[Result_T | None]: # pragma: no cover
...


Expand Down
1 change: 1 addition & 0 deletions cashews/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Command(Enum):
EXPIRE = "expire"
GET_EXPIRE = "get_expire"
CLEAR = "clear"

SET_LOCK = "set_lock"
UNLOCK = "unlock"
IS_LOCKED = "is_locked"
Expand Down
4 changes: 3 additions & 1 deletion cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ def __init__(
cache_instance: Cache = cache,
methods: Sequence[str] = ("get",),
private=True,
prefix_to_disable: str = "",
):
self._private = private
self._cache = cache_instance
self._methods = methods
self._prefix_to_disable = prefix_to_disable
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
Expand All @@ -68,7 +70,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
return await call_next(request)
to_disable = _to_disable(cache_control_value)
if to_disable:
context = self._cache.disabling(*to_disable)
context = self._cache.disabling(*to_disable, prefix=self._prefix_to_disable)
with context, max_age(cache_control_value), self._cache.detect as detector:
response = await call_next(request)
calls = detector.calls_list
Expand Down
7 changes: 5 additions & 2 deletions cashews/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ def _upper(value: TemplateValue) -> TemplateValue:


def default_format(template: KeyTemplate, **values) -> KeyOrTemplate:
_template_context = key_context.get()
_template_context.update(values)
_template_context, rewrite = key_context.get()
if rewrite:
_template_context = {**values, **_template_context}
else:
_template_context = {**_template_context, **values}
return default_formatter.format(template, **_template_context)


Expand Down
12 changes: 4 additions & 8 deletions cashews/helpers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Optional

from ._typing import AsyncCallable_T, AsyncCallableResult_T, Middleware
from ._typing import AsyncCallable_T, Middleware, Result_T
from .backends.interface import Backend
from .commands import PATTERN_CMDS, Command
from .key import get_call_values
from .utils import get_obj_size


def add_prefix(prefix: str) -> Middleware:
async def _middleware(
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
) -> AsyncCallableResult_T:
async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T:
if cmd == Command.GET_MANY:
return await call(*[prefix + key for key in args])
call_values = get_call_values(call, args, kwargs)
Expand All @@ -29,9 +27,7 @@ async def _middleware(


def all_keys_lower() -> Middleware:
async def _middleware(
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
) -> AsyncCallableResult_T:
async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T:
if cmd == Command.GET_MANY:
return await call(*[key.lower() for key in args])
call_values = get_call_values(call, args, kwargs)
Expand All @@ -54,7 +50,7 @@ async def _middleware(
def memory_limit(min_bytes=0, max_bytes=None) -> Middleware:
async def _middleware(
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
) -> Optional[AsyncCallableResult_T]:
) -> Optional[Result_T]:
if cmd != Command.SET:
return await call(*args, **kwargs)
call_values = get_call_values(call, args, kwargs)
Expand Down
2 changes: 1 addition & 1 deletion cashews/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _default(name):
return "*"

check = _ReplaceFormatter(default=_default)
check.format(key, **{**get_key_context(), **func_params})
check.format(key, **{**get_key_context()[0], **func_params})
if errors:
raise WrongKeyError(f"Wrong parameter placeholder '{errors}' in the key ")

Expand Down
11 changes: 7 additions & 4 deletions cashews/key_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@
from contextvars import ContextVar
from typing import Any, Iterator

_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context", default={})
_REWRITE = "__rewrite"
_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context", default={_REWRITE: False})


@contextmanager
def context(**values) -> Iterator[None]:
def context(rewrite=False, **values) -> Iterator[None]:
new_context = {**_template_context.get(), **values}
new_context[_REWRITE] = rewrite
token = _template_context.set(new_context)
try:
yield
finally:
_template_context.reset(token)


def get():
return {**_template_context.get()}
def get() -> tuple[dict[str, Any], bool]:
_context = {**_template_context.get()}
return _context, _context.pop(_REWRITE)


def register(*names: str) -> None:
Expand Down
4 changes: 3 additions & 1 deletion cashews/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .commands import RETRIEVE_CMDS, Command
from .formatter import default_format
from .key import get_call_values
from .key_context import context as template_context


def invalidate(
Expand All @@ -29,7 +30,8 @@ async def _wrap(*args, **kwargs):
if dest in _args:
_args[source] = _args.pop(dest)
key = default_format(key_template, **_args)
await backend.delete_match(key)
with template_context(**_args, rewrite=True):
await backend.delete_match(key)
return result

return _wrap
Expand Down
6 changes: 2 additions & 4 deletions cashews/wrapper/auto_init.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import asyncio

from cashews._typing import AsyncCallable_T, AsyncCallableResult_T, Middleware
from cashews._typing import AsyncCallable_T, Middleware, Result_T
from cashews.backends.interface import Backend
from cashews.commands import Command


def create_auto_init() -> Middleware:
lock = asyncio.Lock()

async def _auto_init(
call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs
) -> AsyncCallableResult_T:
async def _auto_init(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T:
if backend.is_init:
return await call(*args, **kwargs)
async with lock:
Expand Down
21 changes: 20 additions & 1 deletion cashews/wrapper/commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from functools import partial
from typing import TYPE_CHECKING, AsyncIterator, Iterable, Mapping, overload

Expand All @@ -10,7 +11,9 @@
from .wrapper import Wrapper

if TYPE_CHECKING: # pragma: no cover
from cashews._typing import TTL, Default, Key, Value
from cashews._typing import TTL, AsyncCallable_T, Callable_T, Default, Key, Result_T, Value

_empty = object()


class CommandWrapper(Wrapper):
Expand Down Expand Up @@ -40,6 +43,22 @@ async def get(self, key: Key, default: None = None) -> Value | None: ...
async def get(self, key: Key, default: Default | None = None) -> Value | Default | None:
return await self._with_middlewares(Command.GET, key)(key=key, default=default)

async def get_or_set(
self, key: Key, default: Default | AsyncCallable_T | Callable_T, expire: TTL = None
) -> Value | Default | Result_T:
value = await self.get(key, default=_empty)
if value is not _empty:
return value
if callable(default):
if inspect.iscoroutinefunction(default):
_default = await default()
else:
_default = default()
else:
_default = default
await self.set(key, _default, expire=expire)
return default

async def get_raw(self, key: Key) -> Value:
return await self._with_middlewares(Command.GET_RAW, key)(key=key)

Expand Down
9 changes: 6 additions & 3 deletions cashews/wrapper/disable_control.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from contextlib import contextmanager
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Iterator

from cashews.commands import Command
from cashews.exceptions import NotConfiguredError

from .wrapper import Wrapper

Expand All @@ -26,7 +27,8 @@ def __init__(self, name: str = ""):
self.add_middleware(_is_disable_middleware)

def disable(self, *cmds: Command, prefix: str = "") -> None:
return self._get_backend(prefix).disable(*cmds)
with suppress(NotConfiguredError):
return self._get_backend(prefix).disable(*cmds)

def enable(self, *cmds: Command, prefix: str = "") -> None:
return self._get_backend(prefix).enable(*cmds)
Expand All @@ -37,7 +39,8 @@ def disabling(self, *cmds: Command, prefix: str = "") -> Iterator[None]:
try:
yield
finally:
self.enable(*cmds, prefix=prefix)
with suppress(NotConfiguredError):
self.enable(*cmds, prefix=prefix)

def is_disable(self, *cmds: Command, prefix: str = "") -> bool:
return self._get_backend(prefix).is_disable(*cmds)
Expand Down
42 changes: 42 additions & 0 deletions examples/bug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
from collections.abc import Mapping
from typing import Any

from cashews import cache, default_formatter

cache.setup("mem://?size=1000000&check_interval=5")


@default_formatter.register("get_item", preformat=False)
def _getitem_func(mapping: Mapping[str, Any], key: str) -> str:
try:
return str(mapping[key])
except Exception as e:
# when key/tag matching, this may be called with the rendered value
raise RuntimeError(f"{mapping=}, {key=}") from e


@cache(
ttl="1h",
key="prefix:keys:{mapping:get_item(bar)}",
tags=["prefix:tags:{mapping:get_item(bar)}"],
)
async def foo(mapping: str) -> None:
print("Foo", mapping)


@cache.invalidate("prefix:keys:{mapping:get_item(bar)}")
async def bar(mapping: str) -> None:
print("Bar", mapping)


async def main() -> None:
await foo({"bar": "baz"})
await bar({"bar": "baz"})


if __name__ == "__main__":
asyncio.run(main())

# prints Foo {'bar': 'baz'}
# prints Bar {'bar': 'baz'}
2 changes: 1 addition & 1 deletion examples/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def _call(function, *args, **kwargs):
await function(*args, **kwargs)
with cache.detect as detector:
await function(*args, **kwargs)
key = list(detector.keys.keys())[-1]
key = list(detector.calls.keys())[-1]

print(
f"""
Expand Down
2 changes: 2 additions & 0 deletions examples/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ async def basic():
await cache.set("key", 1)
assert await cache.get("key") == 1
await cache.set("key1", value={"any": True}, expire="1m")
print(await cache.get_or_set("key200", default=lambda: "test"))
print(await cache.get_or_set("key10", default="test"))

await cache.set_many({"key2": "test", "key3": Decimal("10.1")}, expire="1m")
print("Get: ", await cache.get("key1")) # -> Any
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ line-length = 119

[tool.ruff.lint]
select = ["E", "F", "B", "I", "SIM", "UP", "C4"]
ignore = ["SIM108"]

[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
Expand Down
Loading

0 comments on commit 4379945

Please sign in to comment.