Skip to content

Commit

Permalink
Fix bug in concrrency
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 28, 2021
1 parent 1cec450 commit 189fb4e
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 4 deletions.
13 changes: 10 additions & 3 deletions di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from anyio.abc import TaskGroup

from di._concurrency import bind_to_stack_as_awaitable, bind_to_stack_as_def_callable
from di._inspect import DependencyParameter, is_coroutine_callable
from di._inspect import (
DependencyParameter,
is_async_gen_callable,
is_coroutine_callable,
)
from di._state import ContainerState
from di._task import Task
from di._topsort import topsort
Expand Down Expand Up @@ -226,7 +230,10 @@ async def bound_call(*args: Any, **kwargs: Any) -> DependencyType:
# if this task is not being parallelized and we are dealing with a sync function
# then we can just execute it directly
# otherwise, we wrap it
if not parallel and not is_coroutine_callable(dependency.call):
if not parallel and not (
is_coroutine_callable(dependency.call)
or is_async_gen_callable(dependency.call)
):
res = bind_to_stack_as_def_callable(dependency.call, stack=stack)(
*args, **kwargs
)
Expand Down Expand Up @@ -299,7 +306,7 @@ async def execute_solved(
if validate_scopes:
self._validate_scopes(solved)
for group in reversed(solved.topsort):
parallel = len(group) == 1
parallel = len(group) > 1
for dep in group:
if dep not in tasks:
tasks[dep] = (
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "di"
version = "0.2.21"
version = "0.2.22"
description = "Autowiring dependency injection"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
77 changes: 77 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import time
from typing import Any, AsyncGenerator, Generator

import anyio
import pytest

from di.container import Container
Expand Down Expand Up @@ -74,39 +76,51 @@ async def test_execute():
assert res.three.zero is res.zero


# TODO: a smarter way to detect concurrency
SLEEP_TIME = 0.05


def sync_callable_func() -> int:
time.sleep(SLEEP_TIME)
return 1


async def async_callable_func() -> int:
await anyio.sleep(SLEEP_TIME)
return 1


def sync_gen_func() -> Generator[int, None, None]:
time.sleep(SLEEP_TIME)
yield 1


async def async_gen_func() -> AsyncGenerator[int, None]:
await anyio.sleep(SLEEP_TIME)
yield 1


class SyncCallableCls:
def __call__(self) -> int:
time.sleep(SLEEP_TIME)
return 1


class AsyncCallableCls:
async def __call__(self) -> int:
await anyio.sleep(SLEEP_TIME)
return 1


class SyncGenCls:
def __call__(self) -> Generator[int, None, None]:
time.sleep(SLEEP_TIME)
yield 1


class AsyncGenCls:
async def __call__(self) -> AsyncGenerator[int, None]:
await anyio.sleep(0.05)
yield 1


Expand All @@ -122,8 +136,71 @@ async def __call__(self) -> AsyncGenerator[int, None]:
SyncGenCls(),
AsyncGenCls(),
],
ids=[
"sync_callable_func",
"async_callable_func",
"sync_gen_func",
"async_gen_func",
"SyncCallableCls",
"AsyncCallableCls",
"SyncGenCls",
"AsyncGenCls",
],
)
@pytest.mark.anyio
async def test_dependency_types(dep: Any):
container = Container()
assert (await container.execute(Dependant(dep))) == 1


@pytest.mark.parametrize(
"dep1",
[
sync_callable_func,
async_callable_func,
sync_gen_func,
async_gen_func,
SyncCallableCls(),
AsyncCallableCls(),
SyncGenCls(),
AsyncGenCls(),
],
ids=[
"sync_callable_func",
"async_callable_func",
"sync_gen_func",
"async_gen_func",
"SyncCallableCls",
"AsyncCallableCls",
"SyncGenCls",
"AsyncGenCls",
],
)
@pytest.mark.parametrize(
"dep2",
[
sync_callable_func,
async_callable_func,
sync_gen_func,
async_gen_func,
],
ids=[
"sync_callable_func",
"async_callable_func",
"sync_gen_func",
"async_gen_func",
],
)
@pytest.mark.anyio
async def test_concurrency(dep1: Any, dep2: Any):
async def collector(v1: int = Depends(dep1), v2: int = Depends(dep2)):
...

container = Container()
solved = container.solve(Dependant(collector))
start = time.time()
await container.execute_solved(solved, validate_scopes=False)
end = time.time()

elapsed = end - start
assert elapsed < 2 * SLEEP_TIME

0 comments on commit 189fb4e

Please sign in to comment.