diff --git a/modal/_utils/async_utils.py b/modal/_utils/async_utils.py
index f4c67e466..c3f48c744 100644
--- a/modal/_utils/async_utils.py
+++ b/modal/_utils/async_utils.py
@@ -490,15 +490,15 @@ async def aclosing(
         await agen.aclose()
 
 
-async def sync_or_async_iter(iterator: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
-    if hasattr(iterator, "__aiter__"):
-        async for item in typing.cast(AsyncIterable[T], iterator):
+async def sync_or_async_iter(iterable: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
+    if hasattr(iterable, "__aiter__"):
+        async for item in typing.cast(AsyncIterable[T], iterable):
             yield item
     else:
         # This intentionally could block the event loop for the duration of calling __iter__ and __next__,
         # so in non-trivial cases (like passing lists and ranges) this could be quite a foot gun for users
         # w/ async code (but they can work around it by always using async iterators)
-        for item in typing.cast(Iterable[T], iterator):
+        for item in typing.cast(Iterable[T], iterable):
             yield item
 
 
@@ -510,3 +510,37 @@ async def async_zip(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGener
             yield tuple(items)
         except StopAsyncIteration:
             break
+
+
+async def async_merge(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
+    queue: asyncio.Queue[Tuple[int, Tuple[str, Union[T, Exception, None]]]] = asyncio.Queue()
+
+    async def producer(producer_id: int, iterable: Union[AsyncIterable[T], Iterable[T]]):
+        try:
+            async for item in sync_or_async_iter(iterable):
+                await queue.put((producer_id, ("value", item)))
+        except Exception as e:
+            await queue.put((producer_id, ("exception", e)))
+        finally:
+            await queue.put((producer_id, ("stop", None)))
+
+    tasks = [asyncio.create_task(producer(i, it)) for i, it in enumerate(inputs)]
+    active_producers = set(range(len(inputs)))
+
+    try:
+        while active_producers:
+            producer_id, (event_type, item) = await queue.get()
+            if event_type == "exception":
+                raise typing.cast(Exception, item)
+            elif event_type == "stop":
+                active_producers.remove(producer_id)
+            else:
+                yield typing.cast(T, item)
+    finally:
+        for task in tasks:
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+
+async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
+    yield await awaitable()
diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py
index c9ef4aeb2..329645872 100644
--- a/modal/_utils/function_utils.py
+++ b/modal/_utils/function_utils.py
@@ -4,7 +4,7 @@
 import os
 from enum import Enum
 from pathlib import Path, PurePosixPath
-from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type
+from typing import Any, AsyncGenerator, Callable, Dict, List, Literal, Optional, Tuple, Type
 
 from grpclib import GRPCError
 from grpclib.exceptions import StreamTerminatedError
@@ -355,7 +355,7 @@ def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:
 
 async def _stream_function_call_data(
     client, function_call_id: str, variant: Literal["data_in", "data_out"]
-) -> AsyncIterator[Any]:
+) -> AsyncGenerator[Any, None]:
     """Read from the `data_in` or `data_out` stream of a function call."""
     last_index = 0
 
diff --git a/modal/functions.py b/modal/functions.py
index 9acf0ea7a..40d32c3e5 100644
--- a/modal/functions.py
+++ b/modal/functions.py
@@ -23,7 +23,6 @@
 )
 
 import typing_extensions
-from aiostream import stream
 from google.protobuf.message import Message
 from grpclib import GRPCError, Status
 from synchronicity.combined_types import MethodWithAio
@@ -38,6 +37,9 @@
 from ._serialization import serialize, serialize_proto_params
 from ._utils.async_utils import (
     TaskContext,
+    aclosing,
+    async_merge,
+    callable_to_agen,
     synchronize_api,
     synchronizer,
     warn_if_generator_is_not_consumed,
@@ -205,22 +207,22 @@ async def poll_function(self, timeout: Optional[float] = None):
         )
 
     async def run_generator(self):
-        data_stream = _stream_function_call_data(self.client, self.function_call_id, variant="data_out")
-        combined_stream = stream.merge(data_stream, stream.call(self.run_function))  # type: ignore
-
         items_received = 0
         items_total: Union[int, None] = None  # populated when self.run_function() completes
-        async with combined_stream.stream() as streamer:
-            async for item in streamer:
-                if isinstance(item, api_pb2.GeneratorDone):
-                    items_total = item.items_total
-                else:
-                    yield item
-                    items_received += 1
-                    # The comparison avoids infinite loops if a non-deterministic generator is retried
-                    # and produces less data in the second run than what was already sent.
-                    if items_total is not None and items_received >= items_total:
-                        break
+        async with aclosing(
+            _stream_function_call_data(self.client, self.function_call_id, variant="data_out")
+        ) as data_stream:
+            async with aclosing(async_merge(data_stream, callable_to_agen(self.run_function))) as streamer:
+                async for item in streamer:
+                    if isinstance(item, api_pb2.GeneratorDone):
+                        items_total = item.items_total
+                    else:
+                        yield item
+                        items_received += 1
+                        # The comparison avoids infinite loops if a non-deterministic generator is retried
+                        # and produces less data in the second run than what was already sent.
+                        if items_total is not None and items_received >= items_total:
+                            break
 
 
 # Wrapper type for api_pb2.FunctionStats
diff --git a/modal/parallel_map.py b/modal/parallel_map.py
index a098899ff..f4ec67bdf 100644
--- a/modal/parallel_map.py
+++ b/modal/parallel_map.py
@@ -11,6 +11,7 @@
 from modal._utils.async_utils import (
     AsyncOrSyncIterable,
     aclosing,
+    async_merge,
     async_zip,
     queue_batch_iterator,
     synchronize_api,
@@ -249,12 +250,13 @@ async def poll_outputs():
 
         assert len(received_outputs) == 0
 
-    response_gen = stream.merge(drain_input_generator(), pump_inputs(), poll_outputs())
-
-    async with response_gen.stream() as streamer:
-        async for response in streamer:
-            if response is not None:
-                yield response.value
+    async with aclosing(drain_input_generator()) as drainer, aclosing(pump_inputs()) as pump, aclosing(
+        poll_outputs()
+    ) as poller:
+        async with aclosing(async_merge(drainer, pump, poller)) as streamer:
+            async for response in streamer:
+                if response is not None:
+                    yield response.value
 
 
 @warn_if_generator_is_not_consumed(function_name="Function.map")
diff --git a/test/async_utils_test.py b/test/async_utils_test.py
index 40e6a2b4b..71b678125 100644
--- a/test/async_utils_test.py
+++ b/test/async_utils_test.py
@@ -12,7 +12,9 @@
 from modal._utils.async_utils import (
     TaskContext,
     aclosing,
+    async_merge,
     async_zip,
+    callable_to_agen,
     queue_batch_iterator,
     retry,
     sync_or_async_iter,
@@ -382,3 +384,75 @@ async def gen2():
         result.append(item)
 
     assert result == [(1, 2)]
+
+
+@pytest.mark.asyncio
+async def test_async_merge():
+    result = []
+    states = []
+
+    gen1_event = asyncio.Event()
+    gen2_event = asyncio.Event()
+    gen3_event = asyncio.Event()
+    gen4_event = asyncio.Event()
+
+    async def gen1():
+        states.append("gen1 enter")
+        try:
+            gen1_event.set()
+            await gen2_event.wait()
+            yield 1
+            gen3_event.set()
+            await gen4_event.wait()
+            yield 2
+        finally:
+            states.append("gen1 exit")
+
+    async def gen2():
+        states.append("gen2 enter")
+        try:
+            await gen1_event.wait()
+            yield 3
+            gen2_event.set()
+            await gen3_event.wait()
+            yield 4
+            gen4_event.set()
+        finally:
+            states.append("gen2 exit")
+
+    async for item in async_merge(gen1(), gen2()):
+        result.append(item)
+
+    assert result == [3, 1, 4, 2]
+    assert sorted(states) == [
+        "gen1 enter",
+        "gen1 exit",
+        "gen2 enter",
+        "gen2 exit",
+    ]
+
+    result.clear()
+    states.clear()
+
+    async for item in async_merge(gen1(), gen2()):
+        break
+
+    assert result == []
+    assert sorted(states) == [
+        "gen1 enter",
+        "gen1 exit",
+        "gen2 enter",
+        "gen2 exit",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_awaitable_to_aiter():
+    async def foo():
+        await asyncio.sleep(0.1)
+        return 42
+
+    result = []
+    async for item in callable_to_agen(foo):
+        result.append(item)
+    assert result == [await foo()]
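
For reference, a minimal usage sketch of the two new helpers (illustrative only, not part of the diff; it assumes async_merge and callable_to_agen are importable from modal._utils.async_utils as added above). async_merge interleaves items from any mix of sync and async iterables as they become available, and callable_to_agen turns a zero-argument coroutine function into a one-item async generator, which is how run_generator merges the data_out stream with the completion of run_function:

# Illustrative sketch; the import path is assumed from the diff above.
import asyncio

from modal._utils.async_utils import async_merge, callable_to_agen


async def ticks():
    # Async source: yields 0, 1, 2 with small delays.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield i


async def slow_result():
    # Coroutine whose single result is merged into the same stream.
    await asyncio.sleep(0.02)
    return "done"


async def main():
    # Items arrive in completion order; the plain list is consumed via sync_or_async_iter.
    async for item in async_merge(ticks(), ["a", "b"], callable_to_agen(slow_result)):
        print(item)


asyncio.run(main())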