Skip to content

Commit

Permalink
Better concurrency tests + fix concurrency bug
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 28, 2021
1 parent 189fb4e commit b66f6bf
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 34 deletions.
2 changes: 2 additions & 0 deletions di/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def __eq__(self, o: object) -> bool:
if type(self) != type(o):
return False
assert isinstance(o, type(self))
if self.shared is False or o.shared is False:
return False
return self.call is o.call

@join_docstring_from(DependantProtocol[Any].is_equivalent)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "di"
version = "0.2.22"
version = "0.2.23"
description = "Autowiring dependency injection"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
132 changes: 99 additions & 33 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import threading
import time
from contextlib import contextmanager
from typing import Any, AsyncGenerator, Generator

import anyio
Expand Down Expand Up @@ -76,51 +78,39 @@ async def test_execute():
assert res.three.zero is res.zero


# TODO: a smarter way to detect concurrency
SLEEP_TIME = 0.05


def sync_callable_func() -> int:
time.sleep(SLEEP_TIME)
return 1


async def async_callable_func() -> int:
await anyio.sleep(SLEEP_TIME)
return 1


def sync_gen_func() -> Generator[int, None, None]:
time.sleep(SLEEP_TIME)
yield 1


async def async_gen_func() -> AsyncGenerator[int, None]:
await anyio.sleep(SLEEP_TIME)
yield 1


class SyncCallableCls:
def __call__(self) -> int:
time.sleep(SLEEP_TIME)
return 1


class AsyncCallableCls:
async def __call__(self) -> int:
await anyio.sleep(SLEEP_TIME)
return 1


class SyncGenCls:
def __call__(self) -> Generator[int, None, None]:
time.sleep(SLEEP_TIME)
yield 1


class AsyncGenCls:
async def __call__(self) -> AsyncGenerator[int, None]:
await anyio.sleep(0.05)
yield 1


Expand Down Expand Up @@ -153,17 +143,85 @@ async def test_dependency_types(dep: Any):
assert (await container.execute(Dependant(dep))) == 1


class Counter:
def __init__(self) -> None:
self._lock = threading.Lock()
self._counter = 0

@property
def counter(self) -> int:
return self._counter

@contextmanager
def acquire(self) -> Generator[None, None, None]:
with self._lock:
self._counter += 1
yield


def sync_callable_func_slow(counter: Counter) -> None:
start = time.time()
with counter.acquire():
while counter.counter < 2:
if time.time() - start > 10:
raise TimeoutError("Tasks did not execute concurrently")
time.sleep(0.005)
return


async def async_callable_func_slow(counter: Counter) -> None:
start = time.time()
with counter.acquire():
while counter.counter < 2:
if time.time() - start > 10:
raise TimeoutError("Tasks did not execute concurrently")
await anyio.sleep(0.005)
return


def sync_gen_func_slow(counter: Counter) -> Generator[None, None, None]:
sync_callable_func_slow(counter)
yield None


async def async_gen_func_slow(counter: Counter) -> AsyncGenerator[None, None]:
await async_callable_func_slow(counter)
yield None


class SyncCallableClsSlow:
def __call__(self, counter: Counter) -> None:
sync_callable_func_slow(counter)


class AsyncCallableClsSlow:
async def __call__(self, counter: Counter) -> None:
await async_callable_func_slow(counter)


class SyncGenClsSlow:
def __call__(self, counter: Counter) -> Generator[None, None, None]:
sync_callable_func_slow(counter)
yield None


class AsyncGenClsSlow:
async def __call__(self, counter: Counter) -> AsyncGenerator[None, None]:
await async_callable_func_slow(counter)
yield None


@pytest.mark.parametrize(
"dep1",
[
sync_callable_func,
async_callable_func,
sync_gen_func,
async_gen_func,
SyncCallableCls(),
AsyncCallableCls(),
SyncGenCls(),
AsyncGenCls(),
sync_callable_func_slow,
async_callable_func_slow,
sync_gen_func_slow,
async_gen_func_slow,
SyncCallableClsSlow(),
AsyncCallableClsSlow(),
SyncGenClsSlow(),
AsyncGenClsSlow(),
],
ids=[
"sync_callable_func",
Expand All @@ -179,28 +237,36 @@ async def test_dependency_types(dep: Any):
@pytest.mark.parametrize(
"dep2",
[
sync_callable_func,
async_callable_func,
sync_gen_func,
async_gen_func,
sync_callable_func_slow,
async_callable_func_slow,
sync_gen_func_slow,
async_gen_func_slow,
SyncCallableClsSlow(),
AsyncCallableClsSlow(),
SyncGenClsSlow(),
AsyncGenClsSlow(),
],
ids=[
"sync_callable_func",
"async_callable_func",
"sync_gen_func",
"async_gen_func",
"SyncCallableCls",
"AsyncCallableCls",
"SyncGenCls",
"AsyncGenCls",
],
)
@pytest.mark.anyio
async def test_concurrency(dep1: Any, dep2: Any):
async def collector(v1: int = Depends(dep1), v2: int = Depends(dep2)):
...

container = Container()
solved = container.solve(Dependant(collector))
start = time.time()
await container.execute_solved(solved, validate_scopes=False)
end = time.time()

elapsed = end - start
assert elapsed < 2 * SLEEP_TIME
counter = Counter()
container.bind(Dependant(lambda: counter), Counter, scope=None)

async def collector(
a: None = Depends(dep1, shared=False), b: None = Depends(dep2, shared=False)
):
...

await container.execute(Dependant(collector))

0 comments on commit b66f6bf

Please sign in to comment.