From 05efa7287ad1c0943295ca9f37d3bead2fdbc4af Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 30 Sep 2021 11:07:23 -0500 Subject: [PATCH] Clean up + store tasks in solved dependencies for lower execution overhead --- di/_concurrency.py | 4 +- di/_inspect.py | 3 + di/_local_scope_context.py | 64 ++++++++ di/_scope_map.py | 2 + di/_state.py | 5 +- di/_task.py | 111 ++++++++++---- di/container.py | 226 +++++++--------------------- di/executors.py | 21 ++- di/types/executor.py | 4 +- di/types/solved.py | 35 +++++ docs/src/gather_deps_example.py | 2 +- docs/src/solved_dependant.py | 2 +- pyproject.toml | 4 +- tests/test_execute.py | 4 +- tests/test_get_flat_dependencies.py | 14 +- 15 files changed, 272 insertions(+), 229 deletions(-) create mode 100644 di/_local_scope_context.py create mode 100644 di/types/solved.py diff --git a/di/_concurrency.py b/di/_concurrency.py index b93c5058..88dd112e 100644 --- a/di/_concurrency.py +++ b/di/_concurrency.py @@ -19,10 +19,10 @@ def inner(*args: Any, **kwargs: Any) -> T: def callable_in_thread_pool(call: Callable[..., T]) -> Callable[..., Awaitable[T]]: - def inner(*args: Any, **kwargs: Any) -> T: + def inner(*args: Any, **kwargs: Any) -> Awaitable[T]: return cast(Awaitable[T], anyio.to_thread.run_sync(curry_context(call))) # type: ignore - return inner # type: ignore + return inner def gurantee_awaitable( diff --git a/di/_inspect.py b/di/_inspect.py index f21e35c3..35720862 100644 --- a/di/_inspect.py +++ b/di/_inspect.py @@ -1,3 +1,4 @@ +import functools import inspect from dataclasses import dataclass from functools import lru_cache @@ -50,6 +51,8 @@ class DependencyParameter(Generic[T]): @lru_cache(maxsize=4096) def is_coroutine_callable(call: DependencyProvider) -> bool: + if isinstance(call, functools.partial): + call = call.func if inspect.isroutine(call): return inspect.iscoroutinefunction(call) if inspect.isclass(call): diff --git a/di/_local_scope_context.py b/di/_local_scope_context.py new file mode 100644 index 00000000..5a1817cb --- /dev/null +++ b/di/_local_scope_context.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import contextvars +from types import TracebackType +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + ContextManager, + Optional, + Type, + Union, + cast, +) + +if TYPE_CHECKING: + from di._state import ContainerState + +from di.types import FusedContextManager +from di.types.scopes import Scope + + +class LocalScopeContext(FusedContextManager[None]): + def __init__( + self, context: contextvars.ContextVar[ContainerState], scope: Scope + ) -> None: + self.context = context + self.scope = scope + self.token: Optional[contextvars.Token[ContainerState]] = None + + def __enter__(self): + current = self.context.get() + new = current.copy() + self.token = self.context.set(new) + self.state_cm = cast(ContextManager[None], new.enter_scope(self.scope)) + self.state_cm.__enter__() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Union[None, bool]: + if self.token is not None: + self.context.reset(self.token) + cm = cast(ContextManager[None], self.state_cm) + return cm.__exit__(exc_type, exc_value, traceback) + + async def __aenter__(self): + current = self.context.get() + new = current.copy() + self.token = self.context.set(new) + self.state_cm = cast(AsyncContextManager[None], new.enter_scope(self.scope)) + await self.state_cm.__aenter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Union[None, bool]: + if self.token is not None: + self.context.reset(self.token) + cm = cast(AsyncContextManager[None], self.state_cm) + return await cm.__aexit__(exc_type, exc_value, traceback) diff --git a/di/_scope_map.py b/di/_scope_map.py index e2f97bff..074914db 100644 --- a/di/_scope_map.py +++ b/di/_scope_map.py @@ -29,6 +29,8 @@ class ScopeMap(Generic[KT, VT]): ChainMap also doesn't allow you to set values anywhere but the left mapping, and we need to set values in arbitrary mappings. """ + __slots__ = ("mappings",) + def __init__(self) -> None: self.mappings: Dict[Scope, Dict[KT, VT]] = {} diff --git a/di/_state.py b/di/_state.py index 65d621ac..1b71fea7 100644 --- a/di/_state.py +++ b/di/_state.py @@ -25,7 +25,10 @@ from di.types.scopes import Scope -class ContainerState: +class ContainerState(object): + + __slots__ = ("binds", "cached_values", "stacks") + def __init__(self) -> None: self.binds: Dict[DependencyProvider, DependantProtocol[Any]] = {} self.cached_values = ScopeMap[DependencyProvider, Any]() diff --git a/di/_task.py b/di/_task.py index 8b0dcfee..7bbece78 100644 --- a/di/_task.py +++ b/di/_task.py @@ -1,9 +1,20 @@ from __future__ import annotations -from typing import Any, Awaitable, Callable, Dict, Generic, List, cast +from contextlib import ExitStack, asynccontextmanager, contextmanager +from typing import Any, Dict, Generic, List, Tuple, cast -from di._inspect import DependencyParameter -from di.types.providers import Dependency, DependencyType +from di._inspect import DependencyParameter, is_coroutine_callable, is_gen_callable +from di._state import ContainerState +from di.exceptions import IncompatibleDependencyError +from di.types.dependencies import DependantProtocol +from di.types.providers import ( + AsyncGeneratorProvider, + CallableProvider, + CoroutineProvider, + Dependency, + DependencyType, + GeneratorProvider, +) _UNSET = object() @@ -11,11 +22,23 @@ class Task(Generic[DependencyType]): def __init__( self, + dependant: DependantProtocol[DependencyType], dependencies: Dict[str, DependencyParameter[Task[Dependency]]], ) -> None: + self.dependant = dependant self.dependencies = dependencies self._result: Any = _UNSET + def _gather_params(self) -> Tuple[List[Dependency], Dict[str, Dependency]]: + positional: List[Dependency] = [] + keyword: Dict[str, Dependency] = {} + for param_name, dep in self.dependencies.items(): + if dep.parameter.kind is dep.parameter.kind.POSITIONAL_ONLY: + positional.append(dep.dependency.get_result()) + else: + keyword[param_name] = dep.dependency.get_result() + return positional, keyword + def get_result(self) -> DependencyType: if self._result is _UNSET: raise ValueError( @@ -25,40 +48,62 @@ def get_result(self) -> DependencyType: class AsyncTask(Task[DependencyType]): - def __init__( - self, - call: Callable[..., Awaitable[DependencyType]], - dependencies: Dict[str, DependencyParameter[Task[Dependency]]], - ) -> None: - self.call = call - super().__init__(dependencies) + async def compute(self, state: ContainerState) -> None: + assert self.dependant.call is not None + args, kwargs = self._gather_params() - async def compute(self) -> None: - positional: List[Task[Dependency]] = [] - keyword: Dict[str, Task[Dependency]] = {} - for param_name, dep in self.dependencies.items(): - if dep.parameter.kind is dep.parameter.kind.POSITIONAL_ONLY: - positional.append(dep.dependency.get_result()) + if self.dependant.shared and state.cached_values.contains(self.dependant.call): + # use cached value + self._result = state.cached_values.get(self.dependant.call) + else: + if is_coroutine_callable(self.dependant.call): + self._result = await cast( + CoroutineProvider[DependencyType], self.dependant.call + )(*args, **kwargs) else: - keyword[param_name] = dep.dependency.get_result() - self._result = await self.call(*positional, **keyword) + stack = state.stacks[self.dependant.scope] + if isinstance(stack, ExitStack): + raise IncompatibleDependencyError( + f"The dependency {self.dependant} is an awaitable dependency" + f" and canot be used in the sync scope {self.dependant.scope}" + ) + self._result = await stack.enter_async_context( + asynccontextmanager( + cast( + AsyncGeneratorProvider[DependencyType], + self.dependant.call, + ) + )(*args, **kwargs) + ) + if self.dependant.shared: + # caching is allowed, now that we have a value we can save it and start using the cache + state.cached_values.set( + self.dependant.call, self._result, scope=self.dependant.scope + ) class SyncTask(Task[DependencyType]): - def __init__( - self, - call: Callable[..., DependencyType], - dependencies: Dict[str, DependencyParameter[Task[Dependency]]], - ) -> None: - self.call = call - super().__init__(dependencies) + def compute(self, state: ContainerState) -> None: + assert self.dependant.call is not None + args, kwargs = self._gather_params() - def compute(self) -> None: - positional: List[Task[Dependency]] = [] - keyword: Dict[str, Task[Dependency]] = {} - for param_name, dep in self.dependencies.items(): - if dep.parameter.kind is dep.parameter.kind.POSITIONAL_ONLY: - positional.append(dep.dependency.get_result()) + if self.dependant.shared and state.cached_values.contains(self.dependant.call): + # use cached value + self._result = state.cached_values.get(self.dependant.call) + else: + if not is_gen_callable(self.dependant.call): + self._result = cast( + CallableProvider[DependencyType], self.dependant.call + )(*args, **kwargs) else: - keyword[param_name] = dep.dependency.get_result() - self._result = self.call(*positional, **keyword) + stack = state.stacks[self.dependant.scope] + self._result = stack.enter_context( + contextmanager( + cast(GeneratorProvider[DependencyType], self.dependant.call) + )(*args, **kwargs) + ) + if self.dependant.shared: + # caching is allowed, now that we have a value we can save it and start using the cache + state.cached_values.set( + self.dependant.call, self._result, scope=self.dependant.scope + ) diff --git a/di/container.py b/di/container.py index 6080839d..b1ed7e59 100644 --- a/di/container.py +++ b/di/container.py @@ -1,22 +1,17 @@ from __future__ import annotations +import functools from collections import deque -from contextlib import ExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar -from dataclasses import dataclass -from types import TracebackType from typing import ( Any, - AsyncContextManager, Callable, ContextManager, Deque, Dict, - Generic, List, Optional, Tuple, - Type, Union, cast, ) @@ -25,15 +20,14 @@ DependencyParameter, is_async_gen_callable, is_coroutine_callable, - is_gen_callable, ) +from di._local_scope_context import LocalScopeContext from di._state import ContainerState from di._task import AsyncTask, SyncTask, Task from di._topsort import topsort from di.exceptions import ( DependencyRegistryError, DuplicateScopeError, - IncompatibleDependencyError, ScopeViolationError, UnknownScopeError, ) @@ -42,33 +36,13 @@ from di.types.dependencies import DependantProtocol from di.types.executor import AsyncExecutor, SyncExecutor from di.types.providers import ( - AsyncGeneratorProvider, - CallableProvider, - CoroutineProvider, Dependency, DependencyProvider, DependencyProviderType, DependencyType, - GeneratorProvider, ) from di.types.scopes import Scope - - -@dataclass -class SolvedDependency(Generic[DependencyType]): - """Representation of a fully solved dependency. - - A fully solved dependency consists of: - - A DAG of sub-dependency paramters. - - A topologically sorted order of execution, where each sublist represents a - group of dependencies that can be executed in parallel. - """ - - dependency: DependantProtocol[DependencyType] - dag: Dict[ - DependantProtocol[Any], Dict[str, DependencyParameter[DependantProtocol[Any]]] - ] - topsort: List[List[DependantProtocol[Any]]] +from di.types.solved import SolvedDependency class Container: @@ -102,45 +76,7 @@ def enter_local_scope(self, scope: Scope) -> FusedContextManager[None]: """ if scope in self._state.stacks: raise DuplicateScopeError(f"Scope {scope} has already been entered!") - - container = self - - class LocalScopeContext(FusedContextManager[None]): - def __enter__(self): - current = container._state - new = current.copy() - self.token = container._context.set(new) - self.state_cm = cast(ContextManager[None], new.enter_scope(scope)) - self.state_cm.__enter__() - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> Union[None, bool]: - container._context.reset(self.token) - cm = cast(ContextManager[None], self.state_cm) - return cm.__exit__(exc_type, exc_value, traceback) - - async def __aenter__(self): - current = container._state - new = current.copy() - self.token = container._context.set(new) - self.state_cm = cast(AsyncContextManager[None], new.enter_scope(scope)) - await self.state_cm.__aenter__() - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> Union[None, bool]: - container._context.reset(self.token) - cm = cast(AsyncContextManager[None], self.state_cm) - return await cm.__aexit__(exc_type, exc_value, traceback) - - return LocalScopeContext() + return LocalScopeContext(self._context, scope) def bind( self, @@ -171,7 +107,7 @@ def solve( will not be changing between calls. """ - if dependency.call in self._state.binds: # type: ignore + if dependency.call in self._state.binds: dependency = self._state.binds[dependency.call] # type: ignore param_graph: Dict[ @@ -225,18 +161,16 @@ def check_equivalent(dep: DependantProtocol[Any]): if subdep not in dep_registry: q.append(subdep) - groups = topsort(dependency, dep_dag) - return SolvedDependency[DependencyType]( - dependency=dependency, dag=param_graph, topsort=groups + topsorted_groups = topsort(dependency, dep_dag) + tasks, get_results = self._build_tasks( + topsorted_groups, dependency, param_graph + ) + return SolvedDependency( + dependency=dependency, + dag=param_graph, + _tasks=tasks, + _get_results=get_results, ) - - def get_flat_subdependants( - self, solved: SolvedDependency[Any] - ) -> List[DependantProtocol[Any]]: - """Get an exhaustive list of all of the dependencies of this dependency, - in no particular order. - """ - return [dep for group in solved.topsort[1:] for dep in group] def _build_task( self, @@ -244,7 +178,6 @@ def _build_task( tasks: Dict[ DependantProtocol[Any], Union[AsyncTask[Dependency], SyncTask[Dependency]] ], - state: ContainerState, dag: Dict[ DependantProtocol[Any], Dict[str, DependencyParameter[DependantProtocol[Any]]], @@ -261,73 +194,34 @@ def _build_task( if is_async_gen_callable(dependency.call) or is_coroutine_callable( dependency.call ): - - async def async_call(*args: Any, **kwargs: Any) -> DependencyType: - assert dependency.call is not None - if dependency.shared and state.cached_values.contains(dependency.call): - # use cached value - res = state.cached_values.get(dependency.call) - else: - if is_coroutine_callable(dependency.call): - res = await cast( - CoroutineProvider[DependencyType], dependency.call - )(*args, **kwargs) - else: - stack = state.stacks[dependency.scope] - if isinstance(stack, ExitStack): - raise IncompatibleDependencyError( - f"The dependency {dependency} is an awaitable dependency" - f" and canot be used in the sync scope {dependency.scope}" - ) - res = await stack.enter_async_context( - asynccontextmanager( - cast( - AsyncGeneratorProvider[DependencyType], - dependency.call, - ) - )(*args, **kwargs) - ) - if dependency.shared: - # caching is allowed, now that we have a value we can save it and start using the cache - state.cached_values.set( - dependency.call, res, scope=dependency.scope - ) - - return cast(DependencyType, res) - - return AsyncTask[DependencyType]( - call=async_call, dependencies=task_dependencies - ) + return AsyncTask(dependant=dependency, dependencies=task_dependencies) else: - # sync - def sync_call(*args: Any, **kwargs: Any) -> DependencyType: - assert dependency.call is not None - if dependency.shared and state.cached_values.contains(dependency.call): - # use cached value - res = state.cached_values.get(dependency.call) - else: - if not is_gen_callable(dependency.call): - res = cast(CallableProvider[DependencyType], dependency.call)( - *args, **kwargs - ) - else: - stack = state.stacks[dependency.scope] - res = stack.enter_context( - contextmanager( - cast(GeneratorProvider[DependencyType], dependency.call) - )(*args, **kwargs) - ) - if dependency.shared: - # caching is allowed, now that we have a value we can save it and start using the cache - state.cached_values.set( - dependency.call, res, scope=dependency.scope - ) - - return cast(DependencyType, res) - - return SyncTask[DependencyType]( - call=sync_call, dependencies=task_dependencies - ) + return SyncTask(dependant=dependency, dependencies=task_dependencies) + + def _build_tasks( + self, + topsort: List[List[DependantProtocol[Any]]], + dependency: DependantProtocol[DependencyType], + dag: Dict[ + DependantProtocol[Any], + Dict[str, DependencyParameter[DependantProtocol[Any]]], + ], + ) -> Tuple[ + List[List[Union[AsyncTask[Dependency], SyncTask[Dependency]]]], + Callable[[], DependencyType], + ]: + tasks: Dict[ + DependantProtocol[Any], Union[AsyncTask[Dependency], SyncTask[Dependency]] + ] = {} + for group in reversed(topsort): + for dep in group: + if dep not in tasks: + tasks[dep] = self._build_task(dep, tasks, dag) + get_result = tasks[dependency].get_result + return ( + list(reversed([[tasks[dep] for dep in group] for group in topsort])), + get_result, + ) def _validate_scopes(self, solved: SolvedDependency[Dependency]) -> None: """Validate that dependencies all have a valid scope and @@ -361,22 +255,6 @@ def check_scope(dep: DependantProtocol[Any]) -> None: check_scope(subdep) check_is_inner(dep, subdep) - def _build_tasks( - self, solved: SolvedDependency[DependencyType] - ) -> Tuple[ - List[List[Union[AsyncTask[Dependency], SyncTask[Dependency]]]], - Callable[[], DependencyType], - ]: - tasks: Dict[ - DependantProtocol[Any], Union[AsyncTask[Dependency], SyncTask[Dependency]] - ] = {} - for group in reversed(solved.topsort): - for dep in group: - if dep not in tasks: - tasks[dep] = self._build_task(dep, tasks, self._state, solved.dag) - get_result = tasks[solved.dependency].get_result - return [[tasks[dep] for dep in group] for group in solved.topsort], get_result - def execute_sync( self, solved: SolvedDependency[DependencyType], @@ -391,7 +269,7 @@ def execute_sync( if validate_scopes: self._validate_scopes(solved) - tasks, get_result = self._build_tasks(solved) + tasks, get_results = solved._tasks, solved._get_results # type: ignore if not hasattr(self._executor, "execute_sync"): raise TypeError( @@ -399,9 +277,12 @@ def execute_sync( ) executor = cast(SyncExecutor, self._executor) - return executor.execute_sync( - [[t.compute for t in group] for group in reversed(tasks)], get_result - ) + stateful_tasks = [ + [functools.partial(t.compute, self._state) for t in group] + for group in tasks + ] + + return executor.execute_sync(stateful_tasks, get_results) # type: ignore async def execute_async( self, @@ -417,7 +298,7 @@ async def execute_async( if validate_scopes: self._validate_scopes(solved) - tasks, get_result = self._build_tasks(solved) + tasks, get_results = solved._tasks, solved._get_results # type: ignore if not hasattr(self._executor, "execute_async"): raise TypeError( @@ -425,6 +306,9 @@ async def execute_async( ) executor = cast(AsyncExecutor, self._executor) - return await executor.execute_async( - [[t.compute for t in group] for group in reversed(tasks)], get_result - ) + stateful_tasks = [ + [functools.partial(t.compute, self._state) for t in group] + for group in tasks + ] + + return await executor.execute_async(stateful_tasks, get_results) # type: ignore diff --git a/di/executors.py b/di/executors.py index d9f7e973..a44e2e33 100644 --- a/di/executors.py +++ b/di/executors.py @@ -1,4 +1,5 @@ import concurrent.futures +import functools import inspect import typing @@ -12,8 +13,15 @@ T = typing.TypeVar("T") -def _all_sync(tasks: typing.Collection[Task]) -> bool: - return not any(inspect.iscoroutinefunction(task) for task in tasks) +def _all_sync(tasks: typing.List[Task]) -> bool: + for task in tasks: + if isinstance(task, functools.partial): + call = task.func # type: ignore + else: + call = task + if inspect.iscoroutinefunction(call): + return False + return True class DefaultExecutor(AsyncExecutor, SyncExecutor): @@ -22,7 +30,7 @@ def __init__(self) -> None: def execute_sync( self, - tasks: typing.List[typing.Collection[Task]], + tasks: typing.List[typing.List[Task]], get_result: typing.Callable[[], ResultType], ) -> ResultType: for task_group in tasks: @@ -43,7 +51,7 @@ def execute_sync( "Cannot execute async dependencies in a SyncExecutor" ) else: - v = next(iter(task_group))() + v = task_group[0]() if inspect.isawaitable(v): raise TypeError( "Cannot execute async dependencies in a SyncExecutor" @@ -52,7 +60,7 @@ def execute_sync( async def execute_async( self, - tasks: typing.List[typing.Collection[Task]], + tasks: typing.List[typing.List[Task]], get_result: typing.Callable[[], ResultType], ) -> ResultType: tg: typing.Optional[anyio.abc.TaskGroup] = None @@ -67,8 +75,7 @@ async def execute_async( for task in task_group: tg.start_soon(gurantee_awaitable(task)) # type: ignore else: - task = next(iter(task_group)) - res = task() + res = next(iter(task_group))() if res is not None and inspect.isawaitable(res): await res return get_result() diff --git a/di/types/executor.py b/di/types/executor.py index 89f8e773..1b003aed 100644 --- a/di/types/executor.py +++ b/di/types/executor.py @@ -10,7 +10,7 @@ class SyncExecutor(typing.Protocol): def execute_sync( self, - tasks: typing.List[typing.Collection[Task]], + tasks: typing.List[typing.List[Task]], get_result: typing.Callable[[], ResultType], ) -> ResultType: raise NotImplementedError @@ -19,7 +19,7 @@ def execute_sync( class AsyncExecutor(typing.Protocol): async def execute_async( self, - tasks: typing.List[typing.Collection[Task]], + tasks: typing.List[typing.List[Task]], get_result: typing.Callable[[], ResultType], ) -> ResultType: raise NotImplementedError diff --git a/di/types/solved.py b/di/types/solved.py new file mode 100644 index 00000000..4754e596 --- /dev/null +++ b/di/types/solved.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import dataclasses +import functools +from typing import Any, Callable, Dict, Generic, List, Union + +from di._inspect import DependencyParameter +from di._task import AsyncTask, SyncTask +from di.types.dependencies import DependantProtocol +from di.types.providers import Dependency, DependencyType + + +@dataclasses.dataclass(frozen=True) +class SolvedDependency(Generic[DependencyType]): + """Representation of a fully solved dependency. + + A fully solved dependency consists of: + - A DAG of sub-dependency paramters. + - A topologically sorted order of execution, where each sublist represents a + group of dependencies that can be executed in parallel. + """ + + dependency: DependantProtocol[DependencyType] + dag: Dict[ + DependantProtocol[Any], Dict[str, DependencyParameter[DependantProtocol[Any]]] + ] + _tasks: List[List[Union[AsyncTask[Dependency], SyncTask[Dependency]]]] + _get_results: Callable[[], DependencyType] + + @functools.cached_property + def flat_subdependants(self) -> List[DependantProtocol[Any]]: + """Get an exhaustive list of all of the dependencies of this dependency, + in no particular order. + """ + return list(self.dag.keys() - {self.dependency}) diff --git a/docs/src/gather_deps_example.py b/docs/src/gather_deps_example.py index 5a53a755..6f509ece 100644 --- a/docs/src/gather_deps_example.py +++ b/docs/src/gather_deps_example.py @@ -37,7 +37,7 @@ async def web_framework() -> None: # about not knowing how to build a Request with container.bind(Dependant(lambda: Request(scopes=[])), Request): scopes = gather_scopes( - container.get_flat_subdependants(container.solve(Dependant(controller))) + container.solve(Dependant(controller)).flat_subdependants ) assert set(scopes) == {"scope1", "scope2"} diff --git a/docs/src/solved_dependant.py b/docs/src/solved_dependant.py index 72f4ddc9..d8d4ef41 100644 --- a/docs/src/solved_dependant.py +++ b/docs/src/solved_dependant.py @@ -4,7 +4,7 @@ import anyio from di import Container, Dependant -from di.container import SolvedDependency +from di.types.solved import SolvedDependency T = TypeVar("T") diff --git a/pyproject.toml b/pyproject.toml index 5ee43901..452035cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "di" -version = "0.4.1" +version = "0.4.2" description = "Autowiring dependency injection" authors = ["Adrian Garcia Badaracco "] readme = "README.md" @@ -44,7 +44,7 @@ mkautodoc = "~0" mike = "~1" [build-system] -requires = ["poetry-core>=1.0.0"] +requires = ["poetry-core>=1.0.6"] build-backend = "poetry.core.masonry.api" [tool.isort] diff --git a/tests/test_execute.py b/tests/test_execute.py index 028769aa..01047274 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -160,7 +160,7 @@ def sync_callable_func_slow(counter: Counter) -> None: start = time.time() with counter.acquire(): while counter.counter < 2: - if time.time() - start > 10: + if time.time() - start > 0.5: raise TimeoutError( "Tasks did not execute concurrently" ) # pragma: no cover @@ -172,7 +172,7 @@ async def async_callable_func_slow(counter: Counter) -> None: start = time.time() with counter.acquire(): while counter.counter < 2: - if time.time() - start > 10: + if time.time() - start > 0.5: raise TimeoutError( "Tasks did not execute concurrently" ) # pragma: no cover diff --git a/tests/test_get_flat_dependencies.py b/tests/test_get_flat_dependencies.py index 2915ab25..d9eedfe8 100644 --- a/tests/test_get_flat_dependencies.py +++ b/tests/test_get_flat_dependencies.py @@ -47,28 +47,28 @@ async def test_get_flat_dependencies(): async with container.enter_global_scope("dummy"): assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call7))), + container.solve(Dependant(call=call7)).flat_subdependants, [call1, call2, call3, call4, call6], ) assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call6))), + container.solve(Dependant(call=call6)).flat_subdependants, [call1, call2, call3, call4], ) assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call5))), + container.solve(Dependant(call=call5)).flat_subdependants, [call1, call2, call3, call4], ) assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call4))), + container.solve(Dependant(call=call4)).flat_subdependants, [call1, call2, call3], ) assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call3))), [] + container.solve(Dependant(call=call3)).flat_subdependants, [] ) assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call2))), + container.solve(Dependant(call=call2)).flat_subdependants, [call1], ) assert_compare_call( - container.get_flat_subdependants(container.solve(Dependant(call=call1))), [] + container.solve(Dependant(call=call1)).flat_subdependants, [] )