Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MOD-2454: remove aiostream dependency: async_merge #2344

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,15 +490,15 @@ async def aclosing(
await agen.aclose()


async def sync_or_async_iter(iterable: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
    """Iterate over `iterable` from async code, whether it is a sync or an async iterable.

    Objects exposing `__aiter__` are consumed with `async for`; everything else is
    consumed with a plain `for` loop.

    Args:
        iterable: A sync or async iterable.

    Yields:
        Each item of `iterable`, in order.
    """
    if hasattr(iterable, "__aiter__"):
        async for item in typing.cast(AsyncIterable[T], iterable):
            yield item
    else:
        # This intentionally could block the event loop for the duration of calling __iter__ and __next__,
        # so in non-trivial cases (like passing lists and ranges) this could be quite a foot gun for users
        # w/ async code (but they can work around it by always using async iterators)
        for item in typing.cast(Iterable[T], iterable):
            yield item


Expand All @@ -510,3 +510,37 @@ async def async_zip(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGener
yield tuple(items)
except StopAsyncIteration:
break


async def async_merge(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
    """Merge several sync or async iterables into a single async stream.

    Every input is drained by its own task, which pushes tagged events onto a shared
    queue. Items are yielded in the order they arrive on that queue.

    Args:
        *inputs: The sync or async iterables to merge.

    Yields:
        Items produced by any of the inputs.

    Raises:
        Exception: The first exception raised by any input that reaches the consumer.
    """
    # Each event is (producer_id, (event_type, payload)), where event_type is one of
    # "value", "exception" or "stop".
    queue: asyncio.Queue[Tuple[int, Tuple[str, Union[T, Exception, None]]]] = asyncio.Queue()

    async def producer(producer_id: int, iterable: Union[AsyncIterable[T], Iterable[T]]):
        try:
            async for item in sync_or_async_iter(iterable):
                await queue.put((producer_id, ("value", item)))
        except Exception as e:
            await queue.put((producer_id, ("exception", e)))
        finally:
            # Always announce termination so the consumer can stop waiting on this producer.
            await queue.put((producer_id, ("stop", None)))

    tasks = [asyncio.create_task(producer(i, it)) for i, it in enumerate(inputs)]
    active_producers = set(range(len(inputs)))

    try:
        while active_producers:
            producer_id, (event_type, item) = await queue.get()
            if event_type == "exception":
                raise typing.cast(Exception, item)
            elif event_type == "stop":
                active_producers.remove(producer_id)
            else:
                yield typing.cast(T, item)
    finally:
        # Runs on normal exhaustion, on error, and when the consumer closes this generator
        # early: cancel whatever is still running and wait for the tasks to finish.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
    """Wrap a zero-argument async callable as an async generator that yields its result once.

    Args:
        awaitable: A callable that takes no arguments and returns an awaitable.

    Yields:
        The single value obtained by calling `awaitable` and awaiting the result.
    """
    yield await awaitable()
4 changes: 2 additions & 2 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from enum import Enum
from pathlib import Path, PurePosixPath
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type
from typing import Any, AsyncGenerator, Callable, Dict, List, Literal, Optional, Tuple, Type

from grpclib import GRPCError
from grpclib.exceptions import StreamTerminatedError
Expand Down Expand Up @@ -355,7 +355,7 @@ def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:

async def _stream_function_call_data(
client, function_call_id: str, variant: Literal["data_in", "data_out"]
) -> AsyncIterator[Any]:
) -> AsyncGenerator[Any, None]:
"""Read from the `data_in` or `data_out` stream of a function call."""
last_index = 0

Expand Down
32 changes: 17 additions & 15 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)

import typing_extensions
from aiostream import stream
from google.protobuf.message import Message
from grpclib import GRPCError, Status
from synchronicity.combined_types import MethodWithAio
Expand All @@ -38,6 +37,9 @@
from ._serialization import serialize, serialize_proto_params
from ._utils.async_utils import (
TaskContext,
aclosing,
async_merge,
callable_to_agen,
synchronize_api,
synchronizer,
warn_if_generator_is_not_consumed,
Expand Down Expand Up @@ -205,22 +207,22 @@ async def poll_function(self, timeout: Optional[float] = None):
)

async def run_generator(self):
    """Yield the items of a generator function call until the call reports completion.

    Merges the `data_out` stream of the function call with the result of
    `self.run_function()`. A merged item that is an `api_pb2.GeneratorDone` supplies
    the total item count; every other item is yielded to the caller.
    """
    items_received = 0
    items_total: Union[int, None] = None  # populated when self.run_function() completes
    async with aclosing(
        _stream_function_call_data(self.client, self.function_call_id, variant="data_out")
    ) as data_stream:
        async with aclosing(async_merge(data_stream, callable_to_agen(self.run_function))) as streamer:
            async for item in streamer:
                if isinstance(item, api_pb2.GeneratorDone):
                    items_total = item.items_total
                else:
                    yield item
                    items_received += 1
                # The comparison avoids infinite loops if a non-deterministic generator is retried
                # and produces less data in the second run than what was already sent.
                if items_total is not None and items_received >= items_total:
                    break


# Wrapper type for api_pb2.FunctionStats
Expand Down
14 changes: 8 additions & 6 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from modal._utils.async_utils import (
AsyncOrSyncIterable,
aclosing,
async_merge,
async_zip,
queue_batch_iterator,
synchronize_api,
Expand Down Expand Up @@ -249,12 +250,13 @@ async def poll_outputs():

assert len(received_outputs) == 0

response_gen = stream.merge(drain_input_generator(), pump_inputs(), poll_outputs())

async with response_gen.stream() as streamer:
async for response in streamer:
if response is not None:
yield response.value
async with aclosing(drain_input_generator()) as drainer, aclosing(pump_inputs()) as pump, aclosing(
poll_outputs()
) as poller:
async with aclosing(async_merge(drainer, pump, poller)) as streamer:
async for response in streamer:
if response is not None:
yield response.value


@warn_if_generator_is_not_consumed(function_name="Function.map")
Expand Down
74 changes: 74 additions & 0 deletions test/async_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from modal._utils.async_utils import (
TaskContext,
aclosing,
async_merge,
async_zip,
callable_to_agen,
queue_batch_iterator,
retry,
sync_or_async_iter,
Expand Down Expand Up @@ -382,3 +384,75 @@ async def gen2():
result.append(item)

assert result == [(1, 2)]


@pytest.mark.asyncio
async def test_async_merge():
    """async_merge interleaves items as they are produced and cleans up its inputs on early exit."""
    states = []

    def make_generators():
        # Fresh events for every scenario: events left set by an earlier run would let both
        # generators run to completion before the consumer sees the first item.
        gen1_event = asyncio.Event()
        gen2_event = asyncio.Event()
        gen3_event = asyncio.Event()
        gen4_event = asyncio.Event()

        async def gen1():
            states.append("gen1 enter")
            try:
                gen1_event.set()
                await gen2_event.wait()
                yield 1
                gen3_event.set()
                await gen4_event.wait()
                yield 2
            finally:
                states.append("gen1 exit")

        async def gen2():
            states.append("gen2 enter")
            try:
                await gen1_event.wait()
                yield 3
                gen2_event.set()
                await gen3_event.wait()
                yield 4
                gen4_event.set()
            finally:
                states.append("gen2 exit")

        return gen1(), gen2()

    # Scenario 1: full consumption. The events force the interleaving 3, 1, 4, 2.
    result = []
    async with aclosing(async_merge(*make_generators())) as stream:
        async for item in stream:
            result.append(item)

    assert result == [3, 1, 4, 2]
    assert sorted(states) == [
        "gen1 enter",
        "gen1 exit",
        "gen2 enter",
        "gen2 exit",
    ]

    result.clear()
    states.clear()

    # Scenario 2: the consumer stops after the first item while both generators are still
    # suspended; closing the merged stream must run their `finally` blocks.
    async with aclosing(async_merge(*make_generators())) as stream:
        async for item in stream:
            result.append(item)
            break

    assert result == [3]
    assert sorted(states) == [
        "gen1 enter",
        "gen1 exit",
        "gen2 enter",
        "gen2 exit",
    ]


@pytest.mark.asyncio
async def test_awaitable_to_aiter():
    """callable_to_agen yields the awaited result of the callable exactly once."""

    async def foo():
        await asyncio.sleep(0.1)
        return 42

    result = []
    async for item in callable_to_agen(foo):
        result.append(item)
    # Compare against the known value instead of awaiting foo() a second time.
    assert result == [42]
Loading