Skip to content

Commit

Permalink
Reorganize
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 30, 2021
1 parent 40c47b5 commit d4a4127
Show file tree
Hide file tree
Showing 26 changed files with 262 additions and 243 deletions.
4 changes: 2 additions & 2 deletions di/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from di.container import Container
from di.dependency import Dependant, DependantProtocol
from di.dependant import Dependant
from di.params import Depends

__all__ = ("Container", "Dependant", "DependantProtocol", "Depends")
__all__ = ("Container", "Dependant", "Depends")
14 changes: 10 additions & 4 deletions di/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@
T = TypeVar("T")


def curry_context(call: Callable[..., T]) -> Callable[..., T]:
ctx = contextvars.copy_context()

def inner(*args: Any, **kwargs: Any) -> T:
return ctx.run(functools.partial(call, *args, **kwargs))

return inner


def callable_in_thread_pool(call: Callable[..., T]) -> Callable[..., Awaitable[T]]:
def inner(*args: Any, **kwargs: Any) -> T:
# Ensure we run in the same context
child = functools.partial(call, *args, **kwargs)
context = contextvars.copy_context()
return cast(Awaitable[T], anyio.to_thread.run_sync(context.run, child)) # type: ignore
return cast(Awaitable[T], anyio.to_thread.run_sync(curry_context(call))) # type: ignore

return inner # type: ignore

Expand Down
5 changes: 3 additions & 2 deletions di/_inspect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import types
from dataclasses import dataclass
from functools import lru_cache
from typing import (
Expand Down Expand Up @@ -78,7 +77,9 @@ def get_annotations(call: DependencyProvider) -> Dict[str, Any]:
types_from: DependencyProvider
if inspect.isclass(call):
types_from = call.__init__ # type: ignore
elif not isinstance(call, types.FunctionType) and hasattr(call, "__call__"):
elif not (inspect.isfunction(call) or inspect.ismethod(call)) and hasattr(
call, "__call__"
):
# callable class
types_from = call.__call__ # type: ignore
else:
Expand Down
2 changes: 1 addition & 1 deletion di/_scope_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from typing import Any, Dict, Generic, Hashable, List, TypeVar, Union, overload

from di.dependency import Scope
from di.exceptions import DuplicateScopeError, UnknownScopeError
from di.types.scopes import Scope

T = TypeVar("T")
KT = TypeVar("KT", bound=Hashable)
Expand Down
8 changes: 4 additions & 4 deletions di/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
)

from di._scope_map import ScopeMap
from di._types import FusedContextManager
from di.dependency import (
DependantProtocol,
from di.types import FusedContextManager
from di.types.dependencies import DependantProtocol
from di.types.providers import (
DependencyProvider,
DependencyProviderType,
DependencyType,
Scope,
)
from di.types.scopes import Scope


class ContainerState:
Expand Down
2 changes: 1 addition & 1 deletion di/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Awaitable, Callable, Dict, Generic, List, cast

from di._inspect import DependencyParameter
from di.dependency import Dependency, DependencyType
from di.types.providers import Dependency, DependencyType

_UNSET = object()

Expand Down
65 changes: 37 additions & 28 deletions di/container.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from __future__ import annotations

from collections import deque
from contextlib import ExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
Expand All @@ -7,7 +8,6 @@
from typing import (
Any,
AsyncContextManager,
Awaitable,
Callable,
ContextManager,
Deque,
Expand All @@ -30,28 +30,28 @@
from di._state import ContainerState
from di._task import AsyncTask, SyncTask, Task
from di._topsort import topsort
from di._types import FusedContextManager
from di.dependency import (
from di.exceptions import (
DependencyRegistryError,
DuplicateScopeError,
IncompatibleDependencyError,
ScopeViolationError,
UnknownScopeError,
)
from di.executors import DefaultExecutor
from di.types import FusedContextManager
from di.types.dependencies import DependantProtocol
from di.types.executor import AsyncExecutor, SyncExecutor
from di.types.providers import (
AsyncGeneratorProvider,
CallableProvider,
CoroutineProvider,
DependantProtocol,
Dependency,
DependencyProvider,
DependencyProviderType,
DependencyType,
GeneratorProvider,
Scope,
)
from di.exceptions import (
DependencyRegistryError,
DuplicateScopeError,
IncompatibleDependencyError,
ScopeViolationError,
UnknownScopeError,
)
from di.executor import Executor
from di.executors import ConcurrentAsyncExecutor
from di.types.scopes import Scope


@dataclass
Expand All @@ -72,13 +72,17 @@ class SolvedDependency(Generic[DependencyType]):


class Container:
def __init__(self, executor: Optional[Executor] = None) -> None:
def __init__(
self, executor: Optional[Union[AsyncExecutor, SyncExecutor]] = None
) -> None:
self._context = ContextVar[ContainerState]("context")
state = ContainerState()
state.cached_values.add_scope("container")
state.cached_values.set(Container, self, scope="container")
self._context.set(state)
self._executor = executor or ConcurrentAsyncExecutor()
self._executor: Union[AsyncExecutor, SyncExecutor] = (
executor or DefaultExecutor()
)

@property
def _state(self) -> ContainerState:
Expand Down Expand Up @@ -389,13 +393,15 @@ def execute_sync(

tasks, get_result = self._build_tasks(solved)

res = self._executor.execute(
[{t.compute for t in group} for group in reversed(tasks)], get_result
)
if not hasattr(self._executor, "execute_sync"):
raise TypeError(
"execute_sync requires an executor implementing the SyncExecutor protocol"
)
executor = cast(SyncExecutor, self._executor)

if inspect.isawaitable(res):
raise RuntimeError
return cast(DependencyType, res)
return executor.execute_sync(
[[t.compute for t in group] for group in reversed(tasks)], get_result
)

async def execute_async(
self,
Expand All @@ -413,9 +419,12 @@ async def execute_async(

tasks, get_result = self._build_tasks(solved)

res = self._executor.execute(
[{t.compute for t in group} for group in reversed(tasks)], get_result
if not hasattr(self._executor, "execute_async"):
raise TypeError(
"execute_async requires an executor implementing the AsyncExecutor protocol"
)
executor = cast(AsyncExecutor, self._executor)

return await executor.execute_async(
[[t.compute for t in group] for group in reversed(tasks)], get_result
)
if inspect.isawaitable(res):
return await cast(Awaitable[DependencyType], res)
return cast(DependencyType, res)
139 changes: 12 additions & 127 deletions di/dependency.py → di/dependant.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,22 @@
from __future__ import annotations

import inspect
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
Generator,
Hashable,
Optional,
Protocol,
TypeVar,
Union,
cast,
overload,
runtime_checkable,
)
from typing import Any, Dict, Optional, cast, overload

from di._docstrings import join_docstring_from
from di._inspect import DependencyParameter, get_parameters, infer_call_from_annotation
from di.exceptions import WiringError

DependencyType = TypeVar("DependencyType")

CallableProvider = Callable[..., DependencyType]
CoroutineProvider = Callable[..., Coroutine[Any, Any, DependencyType]]
GeneratorProvider = Callable[..., Generator[DependencyType, None, None]]
AsyncGeneratorProvider = Callable[..., AsyncGenerator[DependencyType, None]]

DependencyProviderType = Union[
CallableProvider[DependencyType],
CoroutineProvider[DependencyType],
GeneratorProvider[DependencyType],
AsyncGeneratorProvider[DependencyType],
]

Scope = Hashable

Dependency = Any

DependencyProvider = Union[
AsyncGeneratorProvider[Dependency],
CoroutineProvider[Dependency],
GeneratorProvider[Dependency],
CallableProvider[Dependency],
]

from di.types.dependencies import DependantProtocol
from di.types.providers import (
AsyncGeneratorProvider,
CallableProvider,
CoroutineProvider,
DependencyProvider,
DependencyProviderType,
DependencyType,
GeneratorProvider,
)
from di.types.scopes import Scope

_VARIABLE_PARAMETER_KINDS = (
inspect.Parameter.VAR_POSITIONAL,
Expand All @@ -67,91 +37,6 @@ def _is_dependant_protocol_instance(o: object) -> bool:
return isinstance(o, DependantProtocol)


@runtime_checkable
class DependantProtocol(Protocol[DependencyType]):
"""A dependant is an object that can provide the container with:
- A hash, to compare itself against other dependants
- A scope
- A callable that can be used to assemble itself
- The dependants that correspond to the keyword arguments of that callable
"""

call: Optional[DependencyProviderType[DependencyType]]
dependencies: Optional[Dict[str, DependencyParameter[DependantProtocol[Any]]]]
scope: Scope
shared: bool

def __hash__(self) -> int:
"""A unique identifier for this dependency.
By default, dependencies are identified by their call attribute.
This can be overriden to introduce other semantics, e.g. to involve the scope or custom attrbiutes
in dependency identification.
"""
raise NotImplementedError

def __eq__(self, o: object) -> bool:
"""Used in conjunction with __hash__ for mapping lookups of dependencies.
Generally, this should have the same semantics as __hash__ but can check for object identity.
"""
raise NotImplementedError

def is_equivalent(self, other: DependantProtocol[Any]) -> bool:
"""Copare two DependantProtocol implementers for equality.
By default, the two are equal only if they share the same callable and scope.
If two dependencies share the same hash but are not equal, the Container will
report an error.
This is commonly caused by using the same callable under two different scopes.
To remedy this, you can either wrap the callable to give it two different hashes/ids,
or you can create a DependantProtocol implementation that overrides __hash__ and/or __eq__.
"""
raise NotImplementedError

def get_dependencies(
self,
) -> Dict[str, DependencyParameter[DependantProtocol[Any]]]:
"""A cache on top of `gather_dependencies()`"""
raise NotImplementedError

def gather_parameters(self) -> Dict[str, inspect.Parameter]:
"""Collect parameters that this dependency needs to construct itself.
Generally, this means introspecting into this dependencies own callable.
"""
raise NotImplementedError

def create_sub_dependant(
self, call: DependencyProvider, scope: Scope, shared: bool
) -> DependantProtocol[Any]:
"""Create a Dependant instance from a sub-dependency of this Dependency.
This is used in the scenario where a transient dependency is inferred from a type hint.
For example:
>>> class Foo:
>>> ...
>>> def foo_factory(foo: Foo) -> Foo:
>>> return foo
>>> def parent(foo: Dependency(foo_factory)):
>>> ...
In this scenario, `Dependency(foo_factory)` will call `create_sub_dependant(Foo)`.
"""
raise NotImplementedError

def infer_call_from_annotation(
self, param: inspect.Parameter
) -> DependencyProvider:
"""Called when the dependency was not explicitly passed a callable.
It is important to note that param in this context refers to the parameter in this
Dependant's parent.
For example, in the case of `def func(thing: Something = Dependency())` this method
will be called with a Parameter corresponding to Something.
"""
raise NotImplementedError


class Dependant(DependantProtocol[DependencyType]):
@overload
def __init__(
Expand Down
16 changes: 0 additions & 16 deletions di/executor.py

This file was deleted.

Loading

0 comments on commit d4a4127

Please sign in to comment.