Skip to content

Commit

Permalink
Clean up + store tasks in solved dependencies for lower execution ove…
Browse files Browse the repository at this point in the history
…rhead
  • Loading branch information
adriangb committed Sep 30, 2021
1 parent d4a4127 commit 05efa72
Show file tree
Hide file tree
Showing 15 changed files with 272 additions and 229 deletions.
4 changes: 2 additions & 2 deletions di/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def inner(*args: Any, **kwargs: Any) -> T:


def callable_in_thread_pool(call: Callable[..., T]) -> Callable[..., Awaitable[T]]:
def inner(*args: Any, **kwargs: Any) -> T:
def inner(*args: Any, **kwargs: Any) -> Awaitable[T]:
return cast(Awaitable[T], anyio.to_thread.run_sync(curry_context(call))) # type: ignore

return inner # type: ignore
return inner


def gurantee_awaitable(
Expand Down
3 changes: 3 additions & 0 deletions di/_inspect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import inspect
from dataclasses import dataclass
from functools import lru_cache
Expand Down Expand Up @@ -50,6 +51,8 @@ class DependencyParameter(Generic[T]):

@lru_cache(maxsize=4096)
def is_coroutine_callable(call: DependencyProvider) -> bool:
if isinstance(call, functools.partial):
call = call.func
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
Expand Down
64 changes: 64 additions & 0 deletions di/_local_scope_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import contextvars
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Optional,
Type,
Union,
cast,
)

if TYPE_CHECKING:
from di._state import ContainerState

from di.types import FusedContextManager
from di.types.scopes import Scope


class LocalScopeContext(FusedContextManager[None]):
def __init__(
self, context: contextvars.ContextVar[ContainerState], scope: Scope
) -> None:
self.context = context
self.scope = scope
self.token: Optional[contextvars.Token[ContainerState]] = None

def __enter__(self):
current = self.context.get()
new = current.copy()
self.token = self.context.set(new)
self.state_cm = cast(ContextManager[None], new.enter_scope(self.scope))
self.state_cm.__enter__()

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Union[None, bool]:
if self.token is not None:
self.context.reset(self.token)
cm = cast(ContextManager[None], self.state_cm)
return cm.__exit__(exc_type, exc_value, traceback)

async def __aenter__(self):
current = self.context.get()
new = current.copy()
self.token = self.context.set(new)
self.state_cm = cast(AsyncContextManager[None], new.enter_scope(self.scope))
await self.state_cm.__aenter__()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Union[None, bool]:
if self.token is not None:
self.context.reset(self.token)
cm = cast(AsyncContextManager[None], self.state_cm)
return await cm.__aexit__(exc_type, exc_value, traceback)
2 changes: 2 additions & 0 deletions di/_scope_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class ScopeMap(Generic[KT, VT]):
ChainMap also doesn't allow you to set values anywhere but the left mapping, and we need to set values in arbitrary mappings.
"""

__slots__ = ("mappings",)

def __init__(self) -> None:
self.mappings: Dict[Scope, Dict[KT, VT]] = {}

Expand Down
5 changes: 4 additions & 1 deletion di/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from di.types.scopes import Scope


class ContainerState:
class ContainerState(object):

__slots__ = ("binds", "cached_values", "stacks")

def __init__(self) -> None:
self.binds: Dict[DependencyProvider, DependantProtocol[Any]] = {}
self.cached_values = ScopeMap[DependencyProvider, Any]()
Expand Down
111 changes: 78 additions & 33 deletions di/_task.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,44 @@
from __future__ import annotations

from typing import Any, Awaitable, Callable, Dict, Generic, List, cast
from contextlib import ExitStack, asynccontextmanager, contextmanager
from typing import Any, Dict, Generic, List, Tuple, cast

from di._inspect import DependencyParameter
from di.types.providers import Dependency, DependencyType
from di._inspect import DependencyParameter, is_coroutine_callable, is_gen_callable
from di._state import ContainerState
from di.exceptions import IncompatibleDependencyError
from di.types.dependencies import DependantProtocol
from di.types.providers import (
AsyncGeneratorProvider,
CallableProvider,
CoroutineProvider,
Dependency,
DependencyType,
GeneratorProvider,
)

_UNSET = object()


class Task(Generic[DependencyType]):
def __init__(
self,
dependant: DependantProtocol[DependencyType],
dependencies: Dict[str, DependencyParameter[Task[Dependency]]],
) -> None:
self.dependant = dependant
self.dependencies = dependencies
self._result: Any = _UNSET

def _gather_params(self) -> Tuple[List[Dependency], Dict[str, Dependency]]:
positional: List[Dependency] = []
keyword: Dict[str, Dependency] = {}
for param_name, dep in self.dependencies.items():
if dep.parameter.kind is dep.parameter.kind.POSITIONAL_ONLY:
positional.append(dep.dependency.get_result())
else:
keyword[param_name] = dep.dependency.get_result()
return positional, keyword

def get_result(self) -> DependencyType:
if self._result is _UNSET:
raise ValueError(
Expand All @@ -25,40 +48,62 @@ def get_result(self) -> DependencyType:


class AsyncTask(Task[DependencyType]):
def __init__(
self,
call: Callable[..., Awaitable[DependencyType]],
dependencies: Dict[str, DependencyParameter[Task[Dependency]]],
) -> None:
self.call = call
super().__init__(dependencies)
async def compute(self, state: ContainerState) -> None:
assert self.dependant.call is not None
args, kwargs = self._gather_params()

async def compute(self) -> None:
positional: List[Task[Dependency]] = []
keyword: Dict[str, Task[Dependency]] = {}
for param_name, dep in self.dependencies.items():
if dep.parameter.kind is dep.parameter.kind.POSITIONAL_ONLY:
positional.append(dep.dependency.get_result())
if self.dependant.shared and state.cached_values.contains(self.dependant.call):
# use cached value
self._result = state.cached_values.get(self.dependant.call)
else:
if is_coroutine_callable(self.dependant.call):
self._result = await cast(
CoroutineProvider[DependencyType], self.dependant.call
)(*args, **kwargs)
else:
keyword[param_name] = dep.dependency.get_result()
self._result = await self.call(*positional, **keyword)
stack = state.stacks[self.dependant.scope]
if isinstance(stack, ExitStack):
raise IncompatibleDependencyError(
f"The dependency {self.dependant} is an awaitable dependency"
f" and canot be used in the sync scope {self.dependant.scope}"
)
self._result = await stack.enter_async_context(
asynccontextmanager(
cast(
AsyncGeneratorProvider[DependencyType],
self.dependant.call,
)
)(*args, **kwargs)
)
if self.dependant.shared:
# caching is allowed, now that we have a value we can save it and start using the cache
state.cached_values.set(
self.dependant.call, self._result, scope=self.dependant.scope
)


class SyncTask(Task[DependencyType]):
def __init__(
self,
call: Callable[..., DependencyType],
dependencies: Dict[str, DependencyParameter[Task[Dependency]]],
) -> None:
self.call = call
super().__init__(dependencies)
def compute(self, state: ContainerState) -> None:
assert self.dependant.call is not None
args, kwargs = self._gather_params()

def compute(self) -> None:
positional: List[Task[Dependency]] = []
keyword: Dict[str, Task[Dependency]] = {}
for param_name, dep in self.dependencies.items():
if dep.parameter.kind is dep.parameter.kind.POSITIONAL_ONLY:
positional.append(dep.dependency.get_result())
if self.dependant.shared and state.cached_values.contains(self.dependant.call):
# use cached value
self._result = state.cached_values.get(self.dependant.call)
else:
if not is_gen_callable(self.dependant.call):
self._result = cast(
CallableProvider[DependencyType], self.dependant.call
)(*args, **kwargs)
else:
keyword[param_name] = dep.dependency.get_result()
self._result = self.call(*positional, **keyword)
stack = state.stacks[self.dependant.scope]
self._result = stack.enter_context(
contextmanager(
cast(GeneratorProvider[DependencyType], self.dependant.call)
)(*args, **kwargs)
)
if self.dependant.shared:
# caching is allowed, now that we have a value we can save it and start using the cache
state.cached_values.set(
self.dependant.call, self._result, scope=self.dependant.scope
)
Loading

0 comments on commit 05efa72

Please sign in to comment.