Skip to content

Commit

Permalink
feat: add ability to inject values into execute
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Oct 8, 2021
1 parent 7493a1f commit d55656e
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 56 deletions.
4 changes: 2 additions & 2 deletions di/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


from di.container import Container
from di.dependant import Dependant
from di.dependant import Dependant, UnwiredDependant
from di.params import Depends

__all__ = ("Container", "Dependant", "Depends")
__all__ = ("Container", "Dependant", "Depends", "UnwiredDependant")
2 changes: 1 addition & 1 deletion di/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def inner(*args: Any, **kwargs: Any) -> T:

def callable_in_thread_pool(call: Callable[..., T]) -> Callable[..., Awaitable[T]]:
def inner(*args: Any, **kwargs: Any) -> Awaitable[T]:
return anyio.to_thread.run_sync(curry_context(call)) # type: ignore
return anyio.to_thread.run_sync(curry_context(call), *args, **kwargs) # type: ignore

return inner

Expand Down
21 changes: 17 additions & 4 deletions di/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from di._state import ContainerState
from di.exceptions import IncompatibleDependencyError
from di.types.dependencies import DependantProtocol, DependencyParameter
from di.types.executor import Values
from di.types.providers import (
AsyncGeneratorProvider,
CallableProvider,
Expand Down Expand Up @@ -44,13 +45,18 @@ def _gather_params(

class AsyncTask(Task[DependencyType]):
async def compute(
self, state: ContainerState, results: Dict[Task[Any], Any]
self,
state: ContainerState,
results: Dict[Task[Any], Any],
values: Values,
) -> None:
assert self.dependant.call is not None
call = self.dependant.call
args, kwargs = self._gather_params(results)

if self.dependant.share and state.cached_values.contains(self.dependant.call):
if self.dependant.call in values:
results[self] = values[self.dependant.call]
elif self.dependant.share and state.cached_values.contains(self.dependant.call):
# use cached value
results[self] = state.cached_values.get(self.dependant.call)
else:
Expand Down Expand Up @@ -78,12 +84,19 @@ async def compute(


class SyncTask(Task[DependencyType]):
def compute(self, state: ContainerState, results: Dict[Task[Any], Any]) -> None:
def compute(
self,
state: ContainerState,
results: Dict[Task[Any], Any],
values: Values,
) -> None:
assert self.dependant.call is not None
call = self.dependant.call
args, kwargs = self._gather_params(results)

if self.dependant.share and state.cached_values.contains(self.dependant.call):
if call in values:
results[self] = values[call]
elif self.dependant.share and state.cached_values.contains(self.dependant.call):
# use cached value
results[self] = state.cached_values.get(self.dependant.call)
else:
Expand Down
12 changes: 9 additions & 3 deletions di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from di.executors import DefaultExecutor
from di.types import FusedContextManager
from di.types.dependencies import DependantProtocol, DependencyParameter
from di.types.executor import AsyncExecutor, SyncExecutor
from di.types.executor import AsyncExecutor, SyncExecutor, Values
from di.types.providers import (
Dependency,
DependencyProvider,
Expand Down Expand Up @@ -271,14 +271,17 @@ def check_scope(dep: DependantProtocol[Any]) -> None:
def execute_sync(
self,
solved: SolvedDependency[DependencyType],
*,
validate_scopes: bool = True,
values: Optional[Values] = None,
) -> DependencyType:
"""Execute an already solved dependency.
If you are not dynamically changing scopes, you can run once with `validate_scopes=True`
and then disable scope validation in subsequent runs with `validate_scope=False`.
"""
results: Dict[Task[Any], Any] = {}
values = values or {}
with self.enter_local_scope(self._default_scope):
if validate_scopes:
self._validate_scopes(solved)
Expand All @@ -295,19 +298,22 @@ def execute_sync(
[functools.partial(t.compute, self._state, results) for t in group]
for group in tasks
]
return executor.execute_sync(stateful_tasks, lambda: results[tasks[-1][0]]) # type: ignore
return executor.execute_sync(stateful_tasks, lambda: results[tasks[-1][0]], values) # type: ignore

async def execute_async(
self,
solved: SolvedDependency[DependencyType],
*,
validate_scopes: bool = True,
values: Optional[Values] = None,
) -> DependencyType:
"""Execute an already solved dependency.
If you are not dynamically changing scopes, you can run once with `validate_scopes=True`
and then disable scope validation in subsequent runs with `validate_scope=False`.
"""
results: Dict[Task[Any], Any] = {}
values = values or {}
async with self.enter_local_scope(self._default_scope):
if validate_scopes:
self._validate_scopes(solved)
Expand All @@ -326,4 +332,4 @@ async def execute_async(
[functools.partial(t.compute, self._state, results) for t in group]
for group in tasks
]
return await executor.execute_async(stateful_tasks, lambda: results[tasks[-1][0]]) # type: ignore
return await executor.execute_async(stateful_tasks, lambda: results[tasks[-1][0]], values) # type: ignore
10 changes: 10 additions & 0 deletions di/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AsyncGeneratorProvider,
CallableProvider,
CoroutineProvider,
Dependency,
DependencyProvider,
DependencyProviderType,
DependencyType,
Expand Down Expand Up @@ -172,3 +173,12 @@ def create_sub_dependant(
It is recommended to transfer `scope` and possibly `share` to sub-dependencies created in this manner.
"""
return Dependant[Any](call=call, scope=scope, share=share)


class UnwiredDependant(Dependant[Dependency]):
"""A Dependant that does not autowire"""

def get_dependencies(
self,
) -> Dict[str, DependencyParameter[DependantProtocol[Any]]]:
return {}
15 changes: 10 additions & 5 deletions di/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import anyio.abc

from di._concurrency import curry_context, gurantee_awaitable
from di.types.executor import AsyncExecutor, SyncExecutor, Task
from di.types.executor import AsyncExecutor, SyncExecutor, Task, Values

ResultType = typing.TypeVar("ResultType")

Expand All @@ -19,6 +19,7 @@ def execute_sync(
self,
tasks: typing.List[typing.List[Task]],
get_result: typing.Callable[[], ResultType],
values: Values,
) -> ResultType:
for task_group in tasks:
if len(task_group) > 1:
Expand All @@ -28,7 +29,7 @@ def execute_sync(
]
] = []
for task in task_group:
futures.append(self._threadpool.submit(curry_context(task)))
futures.append(self._threadpool.submit(curry_context(task), values))
for future in concurrent.futures.as_completed(futures):
exc = future.exception()
if exc is not None:
Expand All @@ -38,7 +39,7 @@ def execute_sync(
"Cannot execute async dependencies in execute_sync"
)
else:
v = task_group[0]()
v = task_group[0](values)
if inspect.isawaitable(v):
raise TypeError("Cannot execute async dependencies in execute_sync")
return get_result()
Expand All @@ -47,15 +48,19 @@ async def execute_async(
self,
tasks: typing.List[typing.List[Task]],
get_result: typing.Callable[[], ResultType],
values: Values,
) -> ResultType:
# note: there are 2 task group concepts in this function that should not be confused
# to di, tasks groups are a set of Task's that can be executed in parallel
# to anyio, a TaskGroup is a primitive equivalent to a Trio nursery
tg: typing.Optional[anyio.abc.TaskGroup] = None
for task_group in tasks:
if len(task_group) > 1:
if tg is None:
tg = anyio.create_task_group()
async with tg:
for task in task_group:
tg.start_soon(gurantee_awaitable(task)) # type: ignore
tg.start_soon(gurantee_awaitable(task), values) # type: ignore
else:
await gurantee_awaitable(next(iter(task_group)))()
await gurantee_awaitable(next(iter(task_group)))(values)
return get_result()
10 changes: 5 additions & 5 deletions di/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def Depends(
call: Optional[AsyncGeneratorProvider[DependencyType]] = None,
*,
scope: Optional[Scope] = None,
scope: Scope = None,
share: bool = True
) -> DependencyType:
...
Expand All @@ -29,7 +29,7 @@ def Depends(
def Depends(
call: Optional[CoroutineProvider[DependencyType]] = None,
*,
scope: Optional[Scope] = None,
scope: Scope = None,
share: bool = True
) -> DependencyType:
...
Expand All @@ -39,7 +39,7 @@ def Depends(
def Depends(
call: Optional[GeneratorProvider[DependencyType]] = None,
*,
scope: Optional[Scope] = None,
scope: Scope = None,
share: bool = True
) -> DependencyType:
...
Expand All @@ -49,7 +49,7 @@ def Depends(
def Depends(
call: Optional[CallableProvider[DependencyType]] = None,
*,
scope: Optional[Scope] = None,
scope: Scope = None,
share: bool = True
) -> DependencyType:
...
Expand All @@ -58,7 +58,7 @@ def Depends(
def Depends(
call: Optional[DependencyProviderType[DependencyType]] = None,
*,
scope: Optional[Scope] = None,
scope: Scope = None,
share: bool = True
) -> DependencyType:
return Dependant(call=call, scope=scope, share=share) # type: ignore
10 changes: 8 additions & 2 deletions di/types/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
from typing import Awaitable, Callable, List, TypeVar, Union
from typing import Awaitable, Callable, List, Mapping, TypeVar, Union

from di.types.providers import Dependency, DependencyProvider

if sys.version_info < (3, 8):
from typing_extensions import Protocol
Expand All @@ -8,14 +10,17 @@

ResultType = TypeVar("ResultType")

Task = Callable[[], Union[None, Awaitable[None]]]
Values = Mapping[DependencyProvider, Dependency]

Task = Callable[[Values], Union[None, Awaitable[None]]]


class SyncExecutor(Protocol):
def execute_sync(
self,
tasks: List[List[Task]],
get_result: Callable[[], ResultType],
values: Values,
) -> ResultType:
raise NotImplementedError

Expand All @@ -25,5 +30,6 @@ async def execute_async(
self,
tasks: List[List[Task]],
get_result: Callable[[], ResultType],
values: Values,
) -> ResultType:
raise NotImplementedError
18 changes: 3 additions & 15 deletions docs/src/solved_dependant.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from contextvars import ContextVar
from typing import List, TypeVar

import anyio

from di import Container, Dependant
from di import Container, Dependant, UnwiredDependant
from di.types.solved import SolvedDependency

T = TypeVar("T")
Expand All @@ -17,29 +16,18 @@ class RequestLog(List[Request]):
...


request_ctx: ContextVar[Request] = ContextVar("request_ctx")


def get_request() -> Request:
return request_ctx.get()


async def execute_request(
request: Request, container: Container, solved: SolvedDependency[T]
) -> T:
async with container.enter_local_scope("request"):
token = request_ctx.set(request)
try:
return await container.execute_async(solved)
finally:
request_ctx.reset(token)
return await container.execute_async(solved, values={Request: request})


async def framework() -> None:
container = Container()
request_log = RequestLog()
container.bind(Dependant(lambda: request_log, scope="app"), RequestLog)
container.bind(Dependant(get_request, scope="request"), Request)
container.bind(UnwiredDependant(Request, scope="request"), Request)
solved = container.solve(Dependant(controller, scope="request"))
async with container.enter_global_scope("app"):
# simulate concurrent requests
Expand Down
16 changes: 5 additions & 11 deletions docs/src/starlette/src.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
from __future__ import annotations

import contextlib
import contextvars
from typing import Any, Callable

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Route

from di import Container, Dependant

_req: contextvars.ContextVar[Request] = contextvars.ContextVar("req")


def get_request() -> Request:
return _req.get()
from di import Container, Dependant, UnwiredDependant


class WiredGetRoute(Route):
Expand All @@ -29,16 +22,17 @@ def __init__(
solved_endpoint = container.solve(Dependant(endpoint))

async def wrapped_endpoint(request: Request) -> Any:
_req.set(request)
return await container.execute_async(solved_endpoint)
return await container.execute_async(
solved_endpoint, values={Request: request}
)

super().__init__(path=path, endpoint=wrapped_endpoint, methods=["GET"]) # type: ignore


class App(Starlette):
def __init__(self, container: Container | None = None, **kwargs: Any) -> None:
self.container = container or Container()
self.container.bind(Dependant(get_request), Request)
self.container.bind(UnwiredDependant(Request), Request)

@contextlib.asynccontextmanager
async def lifespan(app: App):
Expand Down
8 changes: 3 additions & 5 deletions docs/wiring.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@ For example, `di` will let you "solve" a dependency into a DAG (directed acyclic
You can then provide this solved dependency back to `di` and it will execute it _without any introspection or reflection_.
This means that if you are not dynamically changing your dependency graph, you incurr basically no cost for autowiring.

For example, here is a more advanced use case where the framework takes control of binding the request itself while still allowing the `di` to controll the rest of the dependencies.
We achieve this by:

1. Binding a static function to provide the current request instance.
1. Using an external method (in this case, [convetxvars]) to inject this instance.
For example, here is a more advanced use case where the framework solves the endpoint and then provides the `Request` as a value each time the endpoint is called.

This means that `di` does *not* do any reflection for each request, nor does it have to do dependency resolution.
Instead, only some basic checks on scopes are done and the dependencies are executed with almost no overhead.
Expand All @@ -73,5 +69,7 @@ Instead, only some basic checks on scopes are done and the dependencies are exec
--8<-- "docs/src/solved_dependant.py"
```

To disable scope checks (perhaps something reasonable to do in a web framework after 1 request is processed), you can pass the `validate_scopes=False` parameter to `execute_sync` or `execute_async`.

[binds]: binds.md
[convetxvars]: https://docs.python.org/3/library/contextvars.html
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "di"
version = "0.11.2"
version = "0.12.0"
description = "Autowiring dependency injection"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
Loading

0 comments on commit d55656e

Please sign in to comment.