perf: remove tasks that completed by value
adriangb committed Oct 18, 2021
1 parent ad3b6de commit 2bb419d
Showing 6 changed files with 78 additions and 77 deletions.
34 changes: 11 additions & 23 deletions di/_task.py
@@ -1,13 +1,13 @@
from __future__ import annotations

import typing
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Tuple, cast

from di._inspect import is_async_gen_callable, is_gen_callable
from di._state import ContainerState
from di.exceptions import IncompatibleDependencyError
from di.types.dependencies import DependantProtocol, DependencyParameter
from di.types.executor import Values
from di.types.providers import (
AsyncGeneratorProvider,
CallableProvider,
@@ -17,6 +17,10 @@
)


class Value(typing.NamedTuple):
value: Any


class Task(Generic[DependencyType]):
__slots__ = ("dependant", "dependencies")

@@ -41,19 +45,6 @@ def _gather_params(
keyword[dep.parameter.name] = results[dep.dependency]
return positional, keyword

def use_value(
self, state: ContainerState, results: Dict[Task[Any], Any], values: Values
) -> bool:
assert self.dependant.call is not None
if self.dependant.call in values:
results[self] = values[self.dependant.call]
return True
if self.dependant.share and state.cached_values.contains(self.dependant.call):
# use cached value
results[self] = state.cached_values.get(self.dependant.call)
return True
return False


class AsyncTask(Task[DependencyType]):
__slots__ = ()
@@ -62,14 +53,12 @@ async def compute(
self,
state: ContainerState,
results: Dict[Task[Any], Any],
values: Values,
) -> None:
if self.use_value(state, results, values):
return # pragma: no cover
args, kwargs = self._gather_params(results)

assert self.dependant.call is not None
call = self.dependant.call
args, kwargs = self._gather_params(results)
if is_async_gen_callable(self.dependant.call):
if is_async_gen_callable(call):
stack = state.stacks[self.dependant.scope]
if not isinstance(stack, AsyncExitStack):
raise IncompatibleDependencyError(
@@ -99,13 +88,12 @@ def compute(
self,
state: ContainerState,
results: Dict[Task[Any], Any],
values: Values,
) -> None:
if self.use_value(state, results, values):
return

args, kwargs = self._gather_params(results)

assert self.dependant.call is not None
call = self.dependant.call
args, kwargs = self._gather_params(results)
if is_gen_callable(self.dependant.call):
if TYPE_CHECKING:
call = cast(GeneratorProvider[DependencyType], call)
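With `use_value` gone and the `values` parameter dropped from `compute`, the executor-facing callable is just `compute` pre-bound to the container state and the shared results dict. A rough sketch of that binding, using stand-in names (`task`, `state`, `results`) rather than the real classes:

import functools
from typing import Any, Dict

def bind_task(task: Any, state: Any, results: Dict[Any, Any]) -> Any:
    # compute() now takes only (state, results); the by-value / cached-value
    # short-circuit that use_value() used to perform happens in the container
    # before binding, so a fully satisfied task is never bound at all.
    return functools.partial(task.compute, state, results)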
67 changes: 51 additions & 16 deletions di/container.py
@@ -12,6 +12,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
cast,
@@ -31,13 +32,19 @@
from di.executors import DefaultExecutor
from di.types import FusedContextManager
from di.types.dependencies import DependantProtocol, DependencyParameter
from di.types.executor import AsyncExecutor, SyncExecutor, Values
from di.types.providers import DependencyProviderType, DependencyType
from di.types.executor import AsyncExecutor, SyncExecutor
from di.types.providers import (
DependencyProvider,
DependencyProviderType,
DependencyType,
)
from di.types.scopes import Scope
from di.types.solved import SolvedDependency

Dependency = Any

_BoundTaskGroup = List[Callable[[], Union[Coroutine[Any, Any, None], None]]]


class Container:
_context: ContextVar[ContainerState]
@@ -270,7 +277,7 @@ def execute_sync(
solved: SolvedDependency[DependencyType],
*,
validate_scopes: bool = True,
values: Optional[Values] = None,
values: Optional[Mapping[DependencyProvider, Any]] = None,
) -> DependencyType:
"""Execute an already solved dependency.
@@ -291,18 +298,33 @@ def execute_sync(
)
executor = cast(SyncExecutor, self._executor)

stateful_tasks = [
[functools.partial(t.compute, self._state, results) for t in group]
for group in tasks
]
return executor.execute_sync(stateful_tasks, lambda: results[tasks[-1][0]], values) # type: ignore
bound_groups: List[_BoundTaskGroup] = []
for task_group in tasks:
group: _BoundTaskGroup = []
for task in task_group:
assert task.dependant.call is not None
if task.dependant.call in values:
results[task] = values[task.dependant.call]
elif task.dependant.share and self._state.cached_values.contains(
task.dependant.call
):
results[task] = self._state.cached_values.get(
task.dependant.call
)
else:
group.append(
functools.partial(task.compute, self._state, results)
)
if group:
bound_groups.append(group)
return executor.execute_sync(bound_groups, lambda: results[tasks[-1][0]]) # type: ignore

async def execute_async(
self,
solved: SolvedDependency[DependencyType],
*,
validate_scopes: bool = True,
values: Optional[Values] = None,
values: Optional[Mapping[DependencyProvider, Any]] = None,
) -> DependencyType:
"""Execute an already solved dependency.
@@ -323,10 +345,23 @@ async def execute_async(
)
executor = cast(AsyncExecutor, self._executor)

stateful_tasks: List[
List[Callable[[], Union[Coroutine[Any, Any, None], None]]]
] = [
[functools.partial(t.compute, self._state, results) for t in group]
for group in tasks
]
return await executor.execute_async(stateful_tasks, lambda: results[tasks[-1][0]], values) # type: ignore
bound_groups: List[_BoundTaskGroup] = []
for task_group in tasks:
group: _BoundTaskGroup = []
for task in task_group:
assert task.dependant.call is not None
if task.dependant.call in values:
results[task] = values[task.dependant.call]
elif task.dependant.share and self._state.cached_values.contains(
task.dependant.call
):
results[task] = self._state.cached_values.get(
task.dependant.call
)
else:
group.append(
functools.partial(task.compute, self._state, results)
)
if group:
bound_groups.append(group)
return await executor.execute_async(bound_groups, lambda: results[tasks[-1][0]]) # type: ignore
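The check that used to live in `Task.use_value` now runs once in `Container.execute_sync` / `execute_async` before the task groups are bound, so a provider supplied through `values` (or already cached for a shared dependency) never reaches the executor at all. A hedged usage sketch of that fast path; `Dependant`, `Depends`, `Container.solve`, and the import path are assumed from di's public API of the time and do not appear in this diff:

from di import Container, Dependant, Depends  # import path assumed, not shown here

def get_db() -> object:
    raise RuntimeError("never called: this provider is supplied via `values`")

def endpoint(db: object = Depends(get_db)) -> object:
    return db

container = Container()
solved = container.solve(Dependant(endpoint))
db = object()
# get_db's task is dropped before any task group is bound, so the executor
# only ever runs endpoint itself.
assert container.execute_sync(solved, values={get_db: db}) is db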
15 changes: 6 additions & 9 deletions di/executors.py
@@ -6,7 +6,7 @@
import anyio.abc

from di._concurrency import curry_context, gurantee_awaitable
from di.types.executor import AsyncExecutor, SyncExecutor, Task, Values
from di.types.executor import AsyncExecutor, SyncExecutor, Task

ResultType = typing.TypeVar("ResultType")

@@ -16,11 +16,10 @@ def execute_sync(
self,
tasks: typing.List[typing.List[Task]],
get_result: typing.Callable[[], ResultType],
values: Values,
) -> ResultType:
for task_group in tasks:
for task in task_group:
result = task(values)
result = task()
if inspect.isawaitable(result):
raise TypeError("Cannot execute async dependencies in execute_sync")
return get_result()
@@ -31,7 +30,6 @@ async def execute_async(
self,
tasks: typing.List[typing.List[Task]],
get_result: typing.Callable[[], ResultType],
values: Values,
) -> ResultType:
# note: there are 2 task group concepts in this function that should not be confused
# to di, task groups are a set of Tasks that can be executed in parallel
@@ -43,9 +41,9 @@
tg = anyio.create_task_group()
async with tg:
for task in task_group:
tg.start_soon(gurantee_awaitable(task), values) # type: ignore
tg.start_soon(gurantee_awaitable(task)) # type: ignore
else:
await gurantee_awaitable(next(iter(task_group)))(values)
await gurantee_awaitable(next(iter(task_group)))()
return get_result()


@@ -57,7 +55,6 @@ def execute_sync(
self,
tasks: typing.List[typing.List[Task]],
get_result: typing.Callable[[], ResultType],
values: Values,
) -> ResultType:
for task_group in tasks:
if len(task_group) > 1:
@@ -67,7 +64,7 @@
]
] = []
for task in task_group:
futures.append(self._threadpool.submit(curry_context(task), values))
futures.append(self._threadpool.submit(curry_context(task)))
for future in concurrent.futures.as_completed(futures):
exc = future.exception()
if exc is not None:
@@ -77,7 +74,7 @@
"Cannot execute async dependencies in execute_sync"
)
else:
v = task_group[0](values)
v = task_group[0]()
if inspect.isawaitable(v):
raise TypeError("Cannot execute async dependencies in execute_sync")
return get_result()
10 changes: 2 additions & 8 deletions di/types/executor.py
@@ -1,7 +1,5 @@
import sys
from typing import Any, Awaitable, Callable, List, Mapping, TypeVar, Union

from di.types.providers import DependencyProvider
from typing import Any, Awaitable, Callable, List, TypeVar, Union

if sys.version_info < (3, 8):
from typing_extensions import Protocol
@@ -11,17 +9,14 @@
ResultType = TypeVar("ResultType")
Dependency = Any

Values = Mapping[DependencyProvider, Dependency]

Task = Callable[[Values], Union[None, Awaitable[None]]]
Task = Callable[[], Union[None, Awaitable[None]]]


class SyncExecutor(Protocol):
def execute_sync(
self,
tasks: List[List[Task]],
get_result: Callable[[], ResultType],
values: Values,
) -> ResultType:
raise NotImplementedError

@@ -31,6 +26,5 @@ async def execute_async(
self,
tasks: List[List[Task]],
get_result: Callable[[], ResultType],
values: Values,
) -> ResultType:
raise NotImplementedError
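With `Values` removed from `di.types.executor`, a `Task` is now a plain zero-argument callable, so a custom executor only needs to call each task in each group and then return `get_result()`. A minimal sketch of a conforming synchronous executor, modelled on the sync execution loop in di/executors.py above (nothing beyond what this diff shows is assumed):

import inspect
import typing

from di.types.executor import Task

ResultType = typing.TypeVar("ResultType")


class InlineSyncExecutor:
    def execute_sync(
        self,
        tasks: typing.List[typing.List[Task]],
        get_result: typing.Callable[[], ResultType],
    ) -> ResultType:
        for task_group in tasks:
            for task in task_group:
                # Tasks are pre-bound by the container; no values mapping is passed.
                result = task()
                if inspect.isawaitable(result):
                    raise TypeError("Cannot execute async dependencies in execute_sync")
        return get_result()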
19 changes: 3 additions & 16 deletions tests/test_execute.py
@@ -332,24 +332,10 @@ async def test_concurrent_executions_share_cache(
def get_obj() -> object:
return object()

# use dependencies as delays to ensure that
# collect2 and collect1 do not execute at the exact same time
# otherwise they might not share the cache

async def dep1() -> None:
await anyio.sleep(5e-3)

async def dep2() -> None:
...

async def collect1(
one: None = Depends(dep1), obj: object = Depends(get_obj, scope=scope)
) -> None:
async def collect1(obj: object = Depends(get_obj, scope=scope)) -> None:
objects.append(obj)

async def collect2(
two: None = Depends(dep2), obj: object = Depends(get_obj, scope=scope)
) -> None:
async def collect2(obj: object = Depends(get_obj, scope=scope)) -> None:
objects.append(obj)

container = Container()
@@ -359,6 +345,7 @@ async def collect2(
async with container.enter_global_scope("global"):
async with anyio.create_task_group() as tg:
tg.start_soon(functools.partial(container.execute_async, solved1)) # type: ignore
await anyio.sleep(0.05)
tg.start_soon(functools.partial(container.execute_async, solved2)) # type: ignore

assert (objects[0] is objects[1]) is shared
10 changes: 5 additions & 5 deletions tests/test_executors.py
@@ -4,10 +4,10 @@
import pytest

from di.executors import ConcurrentSyncExecutor, DefaultExecutor, SimpleSyncExecutor
from di.types.executor import SyncExecutor, Values
from di.types.executor import SyncExecutor


async def task(values: Values) -> None:
async def task() -> None:
...


@@ -22,18 +22,18 @@ def test_executing_async_dependencies_in_sync_executor(
with pytest.raises(
TypeError, match="Cannot execute async dependencies in execute_sync"
):
exc.execute_sync(tasks, lambda: None, {}) # type: ignore
exc.execute_sync(tasks, lambda: None) # type: ignore


def test_simple_sync_executor():
class Dep:
called = False

def __call__(self, values: Values) -> None:
def __call__(self) -> None:
self.called = True

tasks = [[Dep()], [Dep(), Dep()]]

exc = SimpleSyncExecutor()

exc.execute_sync(tasks, lambda: None, {})
exc.execute_sync(tasks, lambda: None) # type: ignore
