Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MOD-2454: update async utils #2397

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import typing
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -403,6 +404,7 @@ async def wrapper():

T = TypeVar("T")
P = ParamSpec("P")
V = TypeVar("V")


def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
Expand Down Expand Up @@ -495,47 +497,94 @@ async def sync_or_async_iter(iterable: Union[Iterable[T], AsyncIterable[T]]) ->
async for item in typing.cast(AsyncIterable[T], iterable):
yield item
else:
assert hasattr(iterable, "__iter__"), "sync_or_async_iter requires an iterable or async iterable"
# This intentionally could block the event loop for the duration of calling __iter__ and __next__,
# so in non-trivial cases (like passing lists and ranges) this could be quite a foot gun for users #
# w/ async code (but they can work around it by always using async iterators)
for item in typing.cast(Iterable[T], iterable):
yield item


@typing.overload
def async_zip(
    i1: Union[AsyncIterable[T], Iterable[T]], i2: Union[AsyncIterable[V], Iterable[V]], /
) -> AsyncGenerator[Tuple[T, V], None]:
    ...


@typing.overload
def async_zip(*iterables: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[Tuple[T, ...], None]:
    ...


async def async_zip(*iterables):
    """Zip any mix of sync and async iterables into an async stream of tuples.

    Like the builtin ``zip``: yields one tuple per round of items, stopping
    as soon as any input is exhausted. Unlike a naive implementation, each
    round advances all inputs concurrently via ``asyncio`` tasks, and any
    tasks still pending on early exit are cancelled and awaited so nothing
    leaks.
    """
    tasks = []
    generators = [sync_or_async_iter(it) for it in iterables]
    try:
        while True:
            try:

                async def next_item(gen):
                    return await gen.__anext__()

                # Advance every input concurrently. gather propagates the
                # first StopAsyncIteration, which ends the zip.
                tasks = [asyncio.create_task(next_item(gen)) for gen in generators]
                items = await asyncio.gather(*tasks)
                yield tuple(items)
            except StopAsyncIteration:
                break
    finally:
        # Clean up tasks still pending from the last round — e.g. when the
        # consumer abandoned the generator early, or one input finished
        # while siblings were still being awaited.
        cancelled_tasks = []
        for task in tasks:
            if not task.done():
                task.cancel()
                cancelled_tasks.append(task)
        try:
            await asyncio.gather(*cancelled_tasks)
        except asyncio.CancelledError:
            pass


@dataclass
class ValueWrapper(typing.Generic[T]):
    """Tags an item successfully produced by a merged-stream producer.

    Used by async_merge to distinguish normal values from exceptions and
    end-of-stream sentinels when multiplexing producers through one queue.
    """

    value: T


@dataclass
class ExceptionWrapper:
    """Tags an exception raised inside a merged-stream producer.

    Carrying the exception as a queue payload (instead of letting the
    producer task die silently) lets the consumer re-raise it in its own
    context.
    """

    value: Exception


class StopSentinelType:
    """Marker type signalling that a producer has finished its stream."""


# Shared end-of-stream sentinel; consumers detect it via isinstance checks.
STOP_SENTINEL = StopSentinelType()


async def async_merge(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
queue: asyncio.Queue[Tuple[int, Tuple[str, Union[T, Exception, None]]]] = asyncio.Queue()
queue: asyncio.Queue[Tuple[int, Union[ValueWrapper[T], ExceptionWrapper, StopSentinelType]]] = asyncio.Queue()

async def producer(producer_id: int, iterable: Union[AsyncIterable[T], Iterable[T]]):
try:
async for item in sync_or_async_iter(iterable):
await queue.put((producer_id, ("value", item)))
await queue.put((producer_id, ValueWrapper(item)))
except Exception as e:
await queue.put((producer_id, ("exception", e)))
await queue.put((producer_id, ExceptionWrapper(e)))
finally:
await queue.put((producer_id, ("stop", None)))
await queue.put((producer_id, STOP_SENTINEL))

tasks = [asyncio.create_task(producer(i, it)) for i, it in enumerate(inputs)]
active_producers = set(range(len(inputs)))

try:
while active_producers:
producer_id, (event_type, item) = await queue.get()
if event_type == "exception":
raise typing.cast(Exception, item)
elif event_type == "stop":
producer_id, item = await queue.get()
if isinstance(item, ExceptionWrapper):
raise item.value
elif isinstance(item, StopSentinelType):
active_producers.remove(producer_id)
else:
yield typing.cast(T, item)
yield item.value
finally:
for task in tasks:
task.cancel()
Expand Down
23 changes: 11 additions & 12 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,17 @@ async def run_generator(self):
items_total: Union[int, None] = None # populated when self.run_function() completes
async with aclosing(
_stream_function_call_data(self.client, self.function_call_id, variant="data_out")
) as data_stream:
async with aclosing(async_merge(data_stream, callable_to_agen(self.run_function))) as streamer:
async for item in streamer:
if isinstance(item, api_pb2.GeneratorDone):
items_total = item.items_total
else:
yield item
items_received += 1
# The comparison avoids infinite loops if a non-deterministic generator is retried
# and produces less data in the second run than what was already sent.
if items_total is not None and items_received >= items_total:
break
) as data_stream, aclosing(async_merge(data_stream, callable_to_agen(self.run_function))) as streamer:
async for item in streamer:
if isinstance(item, api_pb2.GeneratorDone):
items_total = item.items_total
else:
yield item
items_received += 1
# The comparison avoids infinite loops if a non-deterministic generator is retried
# and produces less data in the second run than what was already sent.
if items_total is not None and items_received >= items_total:
break


# Wrapper type for api_pb2.FunctionStats
Expand Down
12 changes: 6 additions & 6 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
async_merge,
async_zip,
queue_batch_iterator,
sync_or_async_iter,
synchronize_api,
synchronizer,
warn_if_generator_is_not_consumed,
Expand Down Expand Up @@ -252,11 +253,10 @@ async def poll_outputs():

async with aclosing(drain_input_generator()) as drainer, aclosing(pump_inputs()) as pump, aclosing(
poll_outputs()
) as poller:
async with aclosing(async_merge(drainer, pump, poller)) as streamer:
async for response in streamer:
if response is not None:
yield response.value
) as poller, aclosing(async_merge(drainer, pump, poller)) as streamer:
async for response in streamer:
if response is not None:
yield response.value


@warn_if_generator_is_not_consumed(function_name="Function.map")
Expand Down Expand Up @@ -390,7 +390,7 @@ async def _starmap_async(

async def feed_queue():
# This runs in a main thread event loop, so it doesn't block the synchronizer loop
async with stream.iterate(input_iterator).stream() as streamer:
async with aclosing(sync_or_async_iter(input_iterator)) as streamer:
async for args in streamer:
await raw_input_queue.put.aio((args, kwargs))
await raw_input_queue.put.aio(None) # end-of-input sentinel
Expand Down
Loading
Loading