diff --git a/aiomisc/__init__.py b/aiomisc/__init__.py index f23b08ff..f2cd86e8 100644 --- a/aiomisc/__init__.py +++ b/aiomisc/__init__.py @@ -5,6 +5,10 @@ from .context import Context, get_context from .counters import Statistic, get_statistics from .entrypoint import entrypoint, run +from .gather import ( + gather_graceful, gather_independent, gather_shackled, + wait_first_cancelled_or_exception, wait_graceful, +) from .iterator_wrapper import IteratorWrapper from .periodic import PeriodicCallback from .plugins import plugins @@ -58,6 +62,10 @@ "context_partial", "cutout", "entrypoint", + "gather", + "gather_graceful", + "gather_independent", + "gather_shackled", "get_context", "get_statistics", "io", @@ -76,4 +84,6 @@ "threaded_separate", "timeout", "wait_coroutine", + "wait_first_cancelled_or_exception", + "wait_graceful", ) diff --git a/aiomisc/gather.py b/aiomisc/gather.py new file mode 100644 index 00000000..31f16ab4 --- /dev/null +++ b/aiomisc/gather.py @@ -0,0 +1,254 @@ +import asyncio +from asyncio import ( + CancelledError, Future, Task, create_task, wait, AbstractEventLoop, +) +from contextlib import suppress +from itertools import filterfalse +from time import monotonic +from typing import ( + Any, Coroutine, Iterable, List, Optional, Sequence, Tuple, Union, +) + + +ToC = Union[Task, Coroutine] +dummy = object() + + +async def gather( + *tocs: Optional[ToC], + loop: Optional[AbstractEventLoop] = None, + return_exceptions: bool = False, +) -> list: + """ + Same as `asyncio.gather`, but allows to pass Nones untouched. + :param tocs: list of tasks/coroutines/Nones. + Nones are skipped and returned as is. + :param loop: + :param return_exceptions: whether to return exceptions + :returns: list of task/coroutine return values + """ + ret = [dummy if tfc else None for tfc in tocs] + res = await asyncio.gather( + *filter(None, tocs), loop=loop, return_exceptions=return_exceptions, + ) + for i, val in enumerate(ret): + if val is not None: + ret[i] = res.pop(0) + return ret + + +async def gather_shackled( + *tocs: Optional[ToC], + wait_cancelled: bool = False, +) -> list: + """ + Gather tasks dependently. If any of them is failed, then, other tasks + are cancelled and the original exception is raised. + :param tocs: list of tasks/coroutines/Nones. + Nones are skipped and returned as is. + :param wait_cancelled: whether to wait until all the other tasks are + cancelled upon any fail or external cancellation. + :returns: list of results (values or exceptions) + """ + return await gather_graceful( + primary=tocs, secondary=None, wait_cancelled=wait_cancelled, + ) + + +async def gather_independent( + *tocs: Optional[ToC], + wait_cancelled: bool = False, +) -> list: + """ + Gather tasks independently. If any of them is failed, then, other tasks + are NOT cancelled and processed as is. Any raised exceptions are returned. + :param tocs: list of tasks/coroutines/Nones. + Nones are skipped and returned as is. + :param wait_cancelled: whether to wait until all the other tasks are + cancelled upon primary fail or external cancellation. + :returns: list of results (values or exceptions) + """ + return await gather_graceful( + primary=None, secondary=tocs, wait_cancelled=wait_cancelled, + ) + + +async def gather_graceful( + primary: Optional[Sequence[Optional[ToC]]] = None, *, + secondary: Sequence[Optional[ToC]] = None, + wait_cancelled: bool = False, +) -> Union[list, Tuple[list, list]]: + """ + Gather tasks in two groups - primary and secondary. If any primary + is somehow failed, then, other tasks are cancelled. If secondary is failed, + then, nothing else is done. If any primary is failed, then, will raise the + first exception. Returns two lists of results, one for the primary tasks + (only values) and the other for the secondary tasks (values or exceptions). + :param primary: list of tasks/coroutines/Nones. + Nones are skipped and returned as is. + :param secondary: list of tasks/coroutines/Nones. + Nones are skipped and returned as is. + :param wait_cancelled: whether to wait until all the other tasks are + cancelled upon primary fail or external cancellation. + :returns: either primary results or secondary results or both + :raises ValueError: if both primary and secondary are None + """ + if primary is None and secondary is None: + raise ValueError("Either primary or secondary must not be None") + + tasks_primary = [ + create_task(toc) if isinstance(toc, Coroutine) else toc + for toc in primary or [] + ] + tasks_secondary = [ + create_task(toc) if isinstance(toc, Coroutine) else toc + for toc in secondary or [] + ] + + await wait_graceful( + filter(None, tasks_primary), + filter(None, tasks_secondary), + wait_cancelled=wait_cancelled, + ) + + ret_primary = [] + ret_secondary = [] + if tasks_primary: + ret_primary = await _gather_primary(tasks_primary) + if tasks_secondary: + ret_secondary = await _gather_secondary(tasks_secondary) + + if primary is None and secondary is not None: + return ret_secondary + + if primary is not None and secondary is None: + return ret_primary + + return ret_primary, ret_secondary + + +async def _gather_primary(tasks: Sequence[Optional[Task]]): + return [await task if task else None for task in tasks] + + +async def _gather_secondary(tasks: Sequence[Optional[Task]]): + ret: List[Optional[Any]] = [] + for task in tasks: + if task is None: + ret.append(None) + continue + try: + ret.append(await task) + except CancelledError as e: + # Check whether cancelled internally + if task.cancelled(): + ret.append(e) + continue + raise + except Exception as e: + ret.append(e) + return ret + + +async def wait_graceful( + primary: Optional[Iterable[Task]] = None, + secondary: Optional[Iterable[Task]] = None, + *, + wait_cancelled: bool = False, +): + """ + Waits for the tasks in two groups - primary and secondary. If any primary + is somehow failed, then, other tasks are cancelled. If secondary is failed, + then, nothing else is done. If any primary is failed, then, will raise the + first exception. + :param primary: optional iterable of primary tasks. + :param secondary: optional iterable of secondary tasks. + :param wait_cancelled: whether to wait until all the other tasks are + cancelled upon primary fail or external cancellation. + """ + await _wait_graceful( + primary or [], + secondary or [], + wait_cancelled=wait_cancelled, + ) + + +async def _wait_graceful( + primary: Iterable[Task], + secondary: Iterable[Task], + *, + wait_cancelled: bool = False, +): + primary, secondary = set(primary), set(secondary) + to_cancel = set() + failed_primary_task = None + try: + # If any primary tasks + if primary: + # Wait for primary tasks first + done, pending = await wait_first_cancelled_or_exception(primary) + # If any failed, cancel pending primary and all secondary task + failed_primary_task = _first_cancelled_or_exception(*done) + if failed_primary_task: + to_cancel.update(pending | secondary) + # If no primary task failed, wait for secondary tasks + if not failed_primary_task and secondary: + await wait(secondary) + except CancelledError: + # If was cancelled externally, cancel all tasks + to_cancel.update(primary | secondary) + + # Keep only pending tasks + to_cancel = set(filterfalse(Future.done, to_cancel)) + + # Cancel tasks + for task in to_cancel: + task.cancel() + + # Wait for cancelled tasks to complete suppressing external cancellation + if wait_cancelled and to_cancel: + with suppress(CancelledError): + await wait(to_cancel) + + # If some primary task failed or cancelled internally, raise exception + if failed_primary_task: + return failed_primary_task.result() + # If was cancelled externally + if to_cancel: + raise CancelledError + + +def _first_cancelled_or_exception(*fs: Future): + for fut in fs: + if fut.cancelled() or fut.exception(): + return fut + + +async def wait_first_cancelled_or_exception( + fs: Iterable[Future], *, + loop: Optional[AbstractEventLoop] = None, + timeout: float = None, +): + """ + Waits for the futures until any of them is cancelled or raises an exception + :param Iterable[Future] fs: iterable of future objects to wait for + :param loop: + :param float timeout: wait timeout, same as for `asyncio.wait` + """ + t = monotonic() + done = set() + pending = set(fs) + left = timeout + while pending and (left is None or left > 0): + if left is not None and timeout is not None: + left = timeout - (monotonic() - t) + d, p = await wait( + pending, timeout=left, + return_when=asyncio.FIRST_COMPLETED, + loop=loop, + ) + done.update(d) + pending = p + if _first_cancelled_or_exception(*d): + break + return done, pending diff --git a/docs/source/api/aiomisc.rst b/docs/source/api/aiomisc.rst index 702c9a9e..d5dd35f5 100644 --- a/docs/source/api/aiomisc.rst +++ b/docs/source/api/aiomisc.rst @@ -57,6 +57,13 @@ :members: :undoc-members: +``aiomisc.gather`` module ++++++++++++++++++++++++++++++ + +.. automodule:: aiomisc.gather + :members: + :undoc-members: + ``aiomisc.io`` module +++++++++++++++++++++ diff --git a/docs/source/gather.rst b/docs/source/gather.rst new file mode 100644 index 00000000..5c255346 --- /dev/null +++ b/docs/source/gather.rst @@ -0,0 +1,35 @@ +Gather +=============== + + + +.. code-block:: python + :name: gather + + import aiomisc + + + async def square(val): + return val ** 2 + + res = await aiomisc.gather( + square(2), None, square(3), + ) + assert res == [4, None, 9] + + +.. code-block:: python + :name: gather_exception + + import aiomisc + + + async def foo(): + return ValueError() + + res = await aiomisc.gather( + None, foo(), + return_exceptions=True, + ) + assert res[0] is None + assert isinstance(res[1], ValueError) diff --git a/tests/tests_gather/__init__.py b/tests/tests_gather/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_gather/conftest.py b/tests/tests_gather/conftest.py new file mode 100644 index 00000000..d2cc4147 --- /dev/null +++ b/tests/tests_gather/conftest.py @@ -0,0 +1,16 @@ +from asyncio import CancelledError, sleep + + +async def ok(delay=0): + await sleep(delay) + return 123 + + +async def fail(delay=0): + await sleep(delay) + raise ValueError + + +async def cancel(delay=0): + await sleep(delay) + raise CancelledError diff --git a/tests/tests_gather/test_gather_graceful.py b/tests/tests_gather/test_gather_graceful.py new file mode 100644 index 00000000..5a938a05 --- /dev/null +++ b/tests/tests_gather/test_gather_graceful.py @@ -0,0 +1,198 @@ +from asyncio import CancelledError, create_task, sleep +from contextlib import suppress + +import pytest + +from aiomisc import gather_graceful +from tests.tests_gather.conftest import ok, fail, cancel + + +async def test_gather_nones(): + async def foo(val): + return val + + primary, secondary = await gather_graceful( + [foo(1), foo(2), None, None, foo(3)], + secondary=[foo(4), None, foo(5), None], + ) + assert primary == [1, 2, None, None, 3] + assert secondary == [4, None, 5, None] + + +async def test_gather_tasks(): + async def foo(val): + return val + + primary, secondary = await gather_graceful( + [create_task(foo(1))], secondary=[create_task(foo(2))], + ) + assert primary == [1] + assert secondary == [2] + + +async def test_gather_empty(): + with pytest.raises(ValueError): + await gather_graceful() + + +async def test_gather_all_ok(): + ptask1 = create_task(ok()) + ptask2 = create_task(ok()) + stask1 = create_task(ok()) + stask2 = create_task(ok()) + + await gather_graceful([ptask1, ptask2], secondary=[stask1, stask2]) + assert ptask1.done() and await ptask1 + assert ptask2.done() and await ptask2 + assert stask1.done() and await stask1 + assert stask2.done() and await stask2 + + +async def test_gather_primary_ok(): + ptask1 = create_task(ok()) + ptask2 = create_task(ok()) + + await gather_graceful([ptask1, ptask2]) + assert ptask1.done() and await ptask1 + assert ptask2.done() and await ptask2 + + +async def test_gather_secondary_ok(): + stask1 = create_task(ok()) + stask2 = create_task(ok()) + + await gather_graceful(secondary=[stask1, stask2]) + assert stask1.done() and await stask1 + assert stask2.done() and await stask2 + + +async def test_gather_primary_failed_no_wait(): + ptask1 = create_task(fail()) + ptask2 = create_task(ok(100)) + stask = create_task(ok(100)) + + with pytest.raises(ValueError): + await gather_graceful([ptask1, ptask2], secondary=[stask]) + + assert ptask1.exception() + assert not stask.done() and not ptask2.done() + await sleep(0.01) + assert stask.cancelled() and ptask2.cancelled() + + +async def test_gather_primary_failed_wait_cancelled(): + ptask1 = create_task(fail()) + ptask2 = create_task(ok(100)) + stask = create_task(ok(100)) + + with pytest.raises(ValueError): + await gather_graceful( + [ptask1, ptask2], secondary=[stask], + wait_cancelled=True, + ) + + assert ptask1.exception() + assert stask.cancelled() and ptask2.cancelled() + + +@pytest.mark.parametrize("wait_cancelled", [False, True]) +async def test_gather_secondary_failed(wait_cancelled): + ptask = create_task(ok(0.01)) + stask1 = create_task(fail()) + stask2 = create_task(ok(0.01)) + + await gather_graceful( + [ptask], secondary=[stask1, stask2], + wait_cancelled=wait_cancelled, + ) + assert ptask.done() + assert stask1.exception() + assert stask2.done() and not stask2.cancelled() + + +async def test_gather_primary_cancelled_no_wait(): + ptask1 = create_task(cancel()) + ptask2 = create_task(ok(100)) + stask = create_task(ok(100)) + + with pytest.raises(CancelledError): + await gather_graceful([ptask1, ptask2], secondary=[stask]) + + assert ptask1.cancelled() + assert not stask.done() and not ptask2.done() + await sleep(0.01) + assert stask.cancelled() and ptask2.cancelled() + + +async def test_gather_primary_cancelled_wait_cancelled(): + async def cancel(): + raise CancelledError + + async def ok(): + await sleep(100) + + ptask1 = create_task(cancel()) + ptask2 = create_task(ok()) + stask = create_task(ok()) + + with pytest.raises(CancelledError): + await gather_graceful( + [ptask1, ptask2], secondary=[stask], + wait_cancelled=True, + ) + + assert ptask1.cancelled() + assert ptask2.cancelled() + assert stask.cancelled() + + +@pytest.mark.parametrize("wait_cancelled", [False, True]) +async def test_gather_secondary_cancelled(wait_cancelled): + ptask = create_task(ok(0.01)) + stask1 = create_task(cancel()) + stask2 = create_task(ok(0.01)) + + await gather_graceful( + [ptask], secondary=[stask1, stask2], + wait_cancelled=wait_cancelled, + ) + assert ptask.done() + assert stask1.cancelled() + assert stask2.done() and not stask2.cancelled() + + +async def test_gather_external_cancel_no_wait(): + ptask = create_task(ok(100)) + stask = create_task(ok(100)) + task = create_task(gather_graceful([ptask], secondary=[stask])) + + await sleep(0.01) + assert not ptask.done() and not stask.done() + + task.cancel() + await sleep(0.01) + assert ptask.cancelled() and stask.cancelled() + with pytest.raises(CancelledError): + await task + + +async def test_gather_external_cancel_wait_cancelled(): + async def cancel(): + with suppress(CancelledError): + await sleep(100) + await sleep(0.01) + + ptask = create_task(cancel()) + stask = create_task(cancel()) + task = create_task(gather_graceful([ptask], secondary=[stask])) + + await sleep(0.01) + assert not ptask.done() and not stask.done() + + task.cancel() + await sleep(0.005) + assert not ptask.done() and not stask.done() + await sleep(0.01) + assert ptask.done() and stask.done() + with pytest.raises(CancelledError): + await task diff --git a/tests/tests_gather/test_gather_independent.py b/tests/tests_gather/test_gather_independent.py new file mode 100644 index 00000000..fcfabe7c --- /dev/null +++ b/tests/tests_gather/test_gather_independent.py @@ -0,0 +1,99 @@ +from asyncio import CancelledError, create_task, sleep +from contextlib import suppress + +import pytest + +from aiomisc import gather_independent +from tests.tests_gather.conftest import ok, fail, cancel + + +async def test_gather_nones(): + async def foo(val): + return val + + res = await gather_independent( + foo(4), None, foo(5), None, + ) + assert res == [4, None, 5, None] + + +async def test_gather_tasks(): + async def foo(val): + return val + + res = await gather_independent(create_task(foo(2))) + assert res == [2] + + +async def test_gather_empty(): + assert not await gather_independent() + + +async def test_gather_ok(): + stask1 = create_task(ok()) + stask2 = create_task(ok()) + + await gather_independent(stask1, stask2) + assert stask1.done() and await stask1 + assert stask2.done() and await stask2 + + +@pytest.mark.parametrize("wait_cancelled", [False, True]) +async def test_gather_secondary_failed(wait_cancelled): + stask1 = create_task(fail()) + stask2 = create_task(ok(0.01)) + + await gather_independent( + stask1, stask2, + wait_cancelled=wait_cancelled, + ) + assert stask1.exception() + assert stask2.done() and not stask2.cancelled() + + +@pytest.mark.parametrize("wait_cancelled", [False, True]) +async def test_gather_secondary_cancelled(wait_cancelled): + stask1 = create_task(cancel()) + stask2 = create_task(ok(0.01)) + + await gather_independent( + stask1, stask2, + wait_cancelled=wait_cancelled, + ) + assert stask1.cancelled() + assert stask2.done() and not stask2.cancelled() + + +async def test_gather_external_cancel_no_wait(): + task = create_task(ok(100)) + gtask = create_task(gather_independent(task)) + + await sleep(0.01) + assert not task.done() + + gtask.cancel() + await sleep(0.01) + assert task.cancelled() + with pytest.raises(CancelledError): + await gtask + + +async def test_gather_external_cancel_wait_cancelled(): + async def cancel(): + with suppress(CancelledError): + await sleep(100) + await sleep(0.01) + + task = create_task(cancel()) + gtask = create_task(gather_independent(task)) + + await sleep(0.01) + assert not task.done() + + gtask.cancel() + await sleep(0.005) + assert not task.done() + await sleep(0.01) + assert task.done() + with pytest.raises(CancelledError): + await gtask diff --git a/tests/tests_gather/test_gather_shackled.py b/tests/tests_gather/test_gather_shackled.py new file mode 100644 index 00000000..d178ea53 --- /dev/null +++ b/tests/tests_gather/test_gather_shackled.py @@ -0,0 +1,125 @@ +from asyncio import CancelledError, create_task, sleep +from contextlib import suppress + +import pytest + +from aiomisc import gather_shackled +from tests.tests_gather.conftest import ok, fail, cancel + + +async def test_gather_nones(): + async def foo(val): + return val + + res = await gather_shackled(foo(1), foo(2), None, None, foo(3)) + assert res == [1, 2, None, None, 3] + + +async def test_gather_tasks(): + async def foo(val): + return val + + res = await gather_shackled(create_task(foo(1))) + assert res == [1] + + +async def test_gather_empty(): + assert not await gather_shackled() + + +async def test_gather_ok(): + ptask1 = create_task(ok()) + ptask2 = create_task(ok()) + + await gather_shackled(ptask1, ptask2) + assert ptask1.done() and await ptask1 + assert ptask2.done() and await ptask2 + + +async def test_gather_failed_no_wait(): + ptask1 = create_task(fail()) + ptask2 = create_task(ok(100)) + + with pytest.raises(ValueError): + await gather_shackled(ptask1, ptask2) + + assert ptask1.exception() + assert not ptask2.done() + await sleep(0.01) + assert ptask2.cancelled() + + +async def test_gather_failed_wait_cancelled(): + ptask1 = create_task(fail()) + ptask2 = create_task(ok(100)) + + with pytest.raises(ValueError): + await gather_shackled(ptask1, ptask2, wait_cancelled=True) + + assert ptask1.exception() + assert ptask2.cancelled() + + +async def test_gather_cancelled_no_wait(): + ptask1 = create_task(cancel()) + ptask2 = create_task(ok(100)) + + with pytest.raises(CancelledError): + await gather_shackled(ptask1, ptask2) + + assert ptask1.cancelled() + assert not ptask2.done() + await sleep(0.01) + assert ptask2.cancelled() + + +async def test_gather_cancelled_wait_cancelled(): + async def cancel(): + raise CancelledError + + async def ok(): + await sleep(100) + + ptask1 = create_task(cancel()) + ptask2 = create_task(ok()) + + with pytest.raises(CancelledError): + await gather_shackled(ptask1, ptask2, wait_cancelled=True) + + assert ptask1.cancelled() + assert ptask2.cancelled() + + +async def test_gather_external_cancel_no_wait(): + task = create_task(ok(100)) + gtask = create_task(gather_shackled(task)) + + await sleep(0.01) + assert not task.done() + + gtask.cancel() + await sleep(0.01) + assert task.cancelled() + with pytest.raises(CancelledError): + await gtask + + +async def test_gather_external_cancel_wait_cancelled(): + async def cancel(): + with suppress(CancelledError): + await sleep(100) + await sleep(0.01) + + task = create_task(cancel()) + gtask = create_task(gather_shackled(task)) + + await sleep(0.01) + assert not task.done() + + gtask.cancel() + await sleep(0.005) + assert not task.done() + await sleep(0.01) + assert task.done() + with pytest.raises(CancelledError): + await gtask diff --git a/tests/tests_gather/test_wait.py b/tests/tests_gather/test_wait.py new file mode 100644 index 00000000..1825ac25 --- /dev/null +++ b/tests/tests_gather/test_wait.py @@ -0,0 +1,68 @@ +from asyncio import create_task + +from aiomisc import wait_first_cancelled_or_exception +from tests.tests_gather.conftest import ok, fail, cancel + + +async def test_wait_ok(): + tasks = [ + create_task(ok(0.01)), + create_task(ok(0.02)), + create_task(ok(0.03)), + ] + done, pending = await wait_first_cancelled_or_exception(tasks) + assert done == set(tasks) + assert not pending + for task in done: + assert task.done() + assert await task + + +async def test_wait_timeout(): + tasks = [ + create_task(ok(0.01)), + create_task(ok(0.02)), + create_task(ok(0.03)), + ] + done, pending = await wait_first_cancelled_or_exception( + tasks, timeout=0.015, + ) + assert done == set(tasks[:1]) + assert pending == set(tasks[1:]) + for task in done: + assert task.done() + assert await task + for task in pending: + assert not task.done() + + +async def test_wait_fail(): + tasks = [ + create_task(ok(0.01)), + create_task(fail(0.02)), + create_task(cancel(0.03)), + ] + + done, pending = await wait_first_cancelled_or_exception(tasks) + assert done == {tasks[0], tasks[1]} + assert pending == {tasks[2]} + + assert tasks[0].done() and await tasks[0] + assert tasks[1].done() and tasks[1].exception() + assert not tasks[2].done() + + +async def test_wait_cancel(): + tasks = [ + create_task(ok(0.01)), + create_task(cancel(0.02)), + create_task(fail(0.03)), + ] + + done, pending = await wait_first_cancelled_or_exception(tasks) + assert done == {tasks[0], tasks[1]} + assert pending == {tasks[2]} + + assert tasks[0].done() and await tasks[0] + assert tasks[1].done() and tasks[1].cancelled() + assert not tasks[2].done() diff --git a/tests/tests_gather/test_wait_graceful.py b/tests/tests_gather/test_wait_graceful.py new file mode 100644 index 00000000..f24e06fc --- /dev/null +++ b/tests/tests_gather/test_wait_graceful.py @@ -0,0 +1,164 @@ +from asyncio import CancelledError, create_task, sleep +from contextlib import suppress + +import pytest + +from aiomisc import wait_graceful +from tests.tests_gather.conftest import ok, fail, cancel + + +async def test_wait_all_ok(): + ptask1 = create_task(ok()) + ptask2 = create_task(ok()) + stask1 = create_task(ok()) + stask2 = create_task(ok()) + + await wait_graceful([ptask1, ptask2], [stask1, stask2]) + assert ptask1.done() and await ptask1 + assert ptask2.done() and await ptask2 + assert stask1.done() and await stask1 + assert stask2.done() and await stask2 + + +async def test_wait_primary_ok(): + ptask1 = create_task(ok()) + ptask2 = create_task(ok()) + + await wait_graceful([ptask1, ptask2]) + assert ptask1.done() and await ptask1 + assert ptask2.done() and await ptask2 + + +async def test_wait_secondary_ok(): + stask1 = create_task(ok()) + stask2 = create_task(ok()) + + await wait_graceful(primary=[], secondary=[stask1, stask2]) + assert stask1.done() and await stask1 + assert stask2.done() and await stask2 + + +async def test_wait_primary_failed_no_wait(): + ptask1 = create_task(fail()) + ptask2 = create_task(ok(100)) + stask = create_task(ok(100)) + + with pytest.raises(ValueError): + await wait_graceful([ptask1, ptask2], [stask]) + + assert ptask1.exception() + assert not stask.done() and not ptask2.done() + await sleep(0.01) + assert stask.cancelled() and ptask2.cancelled() + + +async def test_wait_primary_failed_wait_cancelled(): + ptask1 = create_task(fail()) + ptask2 = create_task(ok(100)) + stask = create_task(ok(100)) + + with pytest.raises(ValueError): + await wait_graceful([ptask1, ptask2], [stask], wait_cancelled=True) + + assert ptask1.exception() + assert stask.cancelled() and ptask2.cancelled() + + +@pytest.mark.parametrize("wait_cancelled", [False, True]) +async def test_wait_secondary_failed(wait_cancelled): + ptask = create_task(ok(0.01)) + stask1 = create_task(fail()) + stask2 = create_task(ok(0.01)) + + await wait_graceful( + [ptask], [stask1, stask2], + wait_cancelled=wait_cancelled, + ) + assert ptask.done() + assert stask1.exception() + assert stask2.done() and not stask2.cancelled() + + +async def test_wait_primary_cancelled_no_wait(): + ptask1 = create_task(cancel()) + ptask2 = create_task(ok(100)) + stask = create_task(ok(100)) + + with pytest.raises(CancelledError): + await wait_graceful([ptask1, ptask2], [stask]) + + assert ptask1.cancelled() + assert not stask.done() and not ptask2.done() + await sleep(0.01) + assert stask.cancelled() and ptask2.cancelled() + + +async def test_wait_primary_cancelled_wait_cancelled(): + async def cancel(): + raise CancelledError + + async def ok(): + await sleep(100) + + ptask1 = create_task(cancel()) + ptask2 = create_task(ok()) + stask = create_task(ok()) + + with pytest.raises(CancelledError): + await wait_graceful([ptask1, ptask2], [stask], wait_cancelled=True) + + assert ptask1.cancelled() + assert ptask2.cancelled() + assert stask.cancelled() + + +@pytest.mark.parametrize("wait_cancelled", [False, True]) +async def test_wait_secondary_cancelled(wait_cancelled): + ptask = create_task(ok(0.01)) + stask1 = create_task(cancel()) + stask2 = create_task(ok(0.01)) + + await wait_graceful( + [ptask], [stask1, stask2], + wait_cancelled=wait_cancelled, + ) + assert ptask.done() + assert stask1.cancelled() + assert stask2.done() and not stask2.cancelled() + + +async def test_wait_external_cancel_no_wait(): + ptask = create_task(ok(100)) + stask = create_task(ok(100)) + task = create_task(wait_graceful([ptask], [stask])) + + await sleep(0.01) + assert not ptask.done() and not stask.done() + + task.cancel() + await sleep(0.01) + assert ptask.cancelled() and stask.cancelled() + with pytest.raises(CancelledError): + await task + + +async def test_wait_external_cancel_wait_cancelled(): + async def cancel(): + with suppress(CancelledError): + await sleep(100) + await sleep(0.01) + + ptask = create_task(cancel()) + stask = create_task(cancel()) + task = create_task(wait_graceful([ptask], [stask])) + + await sleep(0.01) + assert not ptask.done() and not stask.done() + + task.cancel() + await sleep(0.005) + assert not ptask.done() and not stask.done() + await sleep(0.01) + assert ptask.done() and stask.done() + with pytest.raises(CancelledError): + await task