diff --git a/modal/_utils/async_utils.py b/modal/_utils/async_utils.py
index 57c0be63f..1dc3fdd77 100644
--- a/modal/_utils/async_utils.py
+++ b/modal/_utils/async_utils.py
@@ -3,6 +3,7 @@
 import concurrent.futures
 import functools
 import inspect
+import itertools
 import time
 import typing
 from contextlib import asynccontextmanager
@@ -25,7 +26,7 @@
 )
 
 import synchronicity
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, assert_type
 
 from ..exception import InvalidError
 from .logger import logger
@@ -561,35 +562,183 @@ class StopSentinelType:
 STOP_SENTINEL = StopSentinelType()
 
 
-async def async_merge(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
-    queue: asyncio.Queue[Tuple[int, Union[ValueWrapper[T], ExceptionWrapper, StopSentinelType]]] = asyncio.Queue()
+async def async_merge(*iterables: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
+    """Merge multiple sync or async iterables into one stream, yielding items as they arrive."""
+    queue: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper]] = asyncio.Queue(
+        maxsize=len(iterables) * 10
+    )
 
-    async def producer(producer_id: int, iterable: Union[AsyncIterable[T], Iterable[T]]):
+    async def producer(iterable: Union[AsyncIterable[T], Iterable[T]]):
         try:
             async for item in sync_or_async_iter(iterable):
-                await queue.put((producer_id, ValueWrapper(item)))
+                await queue.put(ValueWrapper(item))
         except Exception as e:
-            await queue.put((producer_id, ExceptionWrapper(e)))
-        finally:
-            await queue.put((producer_id, STOP_SENTINEL))
+            await queue.put(ExceptionWrapper(e))
 
-    tasks = [asyncio.create_task(producer(i, it)) for i, it in enumerate(inputs)]
-    active_producers = set(range(len(inputs)))
+    tasks = {asyncio.create_task(producer(it)) for it in iterables}
+    new_output_task = asyncio.create_task(queue.get())
 
     try:
-        while active_producers:
-            producer_id, item = await queue.get()
-            if isinstance(item, ExceptionWrapper):
-                raise item.value
-            elif isinstance(item, StopSentinelType):
-                active_producers.remove(producer_id)
-            else:
+        while tasks:
+            done, _ = await asyncio.wait(
+                [*tasks, new_output_task],
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+
+            if new_output_task in done:
+                item = new_output_task.result()
+                if isinstance(item, ValueWrapper):
+                    yield item.value
+                else:
+                    assert_type(item, ExceptionWrapper)
+                    raise item.value
+
+                new_output_task = asyncio.create_task(queue.get())
+
+            finished_producers = done & tasks
+            tasks -= finished_producers
+            for finished_producer in finished_producers:
+                # Await each finished producer so that errors and cancellations
+                # raised inside it are re-raised here as soon as they happen.
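+                # (Producers route ordinary exceptions into the queue as
+                # ExceptionWrapper, so this mainly surfaces cancellation and
+                # unexpected internal errors.)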
+                await finished_producer
+
+        while not queue.empty():
+            item = await new_output_task
+            if isinstance(item, ValueWrapper):
                 yield item.value
+            else:
+                assert_type(item, ExceptionWrapper)
+                raise item.value
+
+            new_output_task = asyncio.create_task(queue.get())
+
     finally:
+        if not new_output_task.done():
+            new_output_task.cancel()
         for task in tasks:
-            task.cancel()
-        await asyncio.gather(*tasks, return_exceptions=True)
+            if not task.done():
+                try:
+                    task.cancel()
+                    await task
+                except asyncio.CancelledError:
+                    pass
 
 
 async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
     yield await awaitable()
+
+
+async def async_map(
+    input_iterable: Union[AsyncIterable[T], Iterable[T]],
+    async_mapper_func: Callable[[T], Awaitable[V]],
+    concurrency: int,
+) -> AsyncGenerator[V, None]:
+    """Apply async_mapper_func to each input item using up to `concurrency`
+    concurrent workers, yielding results in completion order."""
+    input_queue: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)
+    output_queue: asyncio.Queue[Union[ValueWrapper[V], ExceptionWrapper]] = asyncio.Queue(maxsize=concurrency * 2)
+
+    async def producer():
+        async for item in sync_or_async_iter(input_iterable):
+            await input_queue.put(ValueWrapper(item))
+        await input_queue.put(STOP_SENTINEL)
+
+    async def worker():
+        while True:
+            item = await input_queue.get()
+            try:
+                if isinstance(item, ValueWrapper):
+                    try:
+                        res = await async_mapper_func(item.value)
+                    except Exception as e:
+                        await output_queue.put(ExceptionWrapper(e))
+                        continue
+
+                    await output_queue.put(ValueWrapper(res))
+                else:
+                    assert_type(item, StopSentinelType)
+                    break
+            finally:
+                input_queue.task_done()
+
+    producer_task = asyncio.create_task(producer())
+    worker_tasks = [asyncio.create_task(worker()) for _ in range(concurrency)]
+
+    async def complete_map():
+        await producer_task
+        await input_queue.join()
+
+    complete_map_task = asyncio.create_task(complete_map())
+    all_tasks = [*worker_tasks, complete_map_task]
+
+    try:
+        new_output_task = asyncio.create_task(output_queue.get())
+        try:
+            while True:
+                done, _ = await asyncio.wait(
+                    [*all_tasks, new_output_task],
+                    return_when=asyncio.FIRST_COMPLETED,
+                )
+                if new_output_task in done:
+                    item = new_output_task.result()
+                    if isinstance(item, ValueWrapper):
+                        yield item.value
+                    else:
+                        assert_type(item, ExceptionWrapper)
+                        raise item.value
+
+                    new_output_task = asyncio.create_task(output_queue.get())
+
+                finished_workers = done & set(worker_tasks)
+                for finished_worker in finished_workers:
+                    # Await each finished worker so that errors and cancellations
+                    # raised inside it are re-raised here as soon as they happen.
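+                    # (Workers forward mapper exceptions through output_queue as
+                    # ExceptionWrapper, so this mainly surfaces cancellation and
+                    # unexpected internal errors.)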
+                    await finished_worker
+
+                if complete_map_task in done:
+                    break
+        finally:
+            if not new_output_task.done():
+                new_output_task.cancel()
+
+        while not output_queue.empty():
+            item = output_queue.get_nowait()
+            if isinstance(item, ValueWrapper):
+                yield item.value
+            else:
+                assert_type(item, ExceptionWrapper)
+                raise item.value
+
+        await complete_map_task  # raises potential errors from the producer
+
+    finally:
+        for task in all_tasks:
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+
+async def async_map_ordered(
+    input_iterable: Union[AsyncIterable[T], Iterable[T]],
+    async_mapper_func: Callable[[T], Awaitable[V]],
+    concurrency: int,
+) -> AsyncGenerator[V, None]:
+    """Like async_map, but buffers out-of-order results and yields them in input order."""
+    async def mapper_func_wrapper(tup: Tuple[int, T]) -> Tuple[int, V]:
+        return (tup[0], await async_mapper_func(tup[1]))
+
+    async def counter() -> AsyncGenerator[int, None]:
+        for i in itertools.count():
+            yield i
+
+    next_idx = 0
+    buffer = {}
+
+    async with aclosing(counter()) as counter_gen, aclosing(
+        async_zip(counter_gen, input_iterable)
+    ) as zipped_input, aclosing(async_map(zipped_input, mapper_func_wrapper, concurrency)) as stream:
+        async for output_idx, output_item in stream:
+            buffer[output_idx] = output_item
+
+            while next_idx in buffer:
+                yield buffer[next_idx]
+                del buffer[next_idx]
+                next_idx += 1
diff --git a/modal/mount.py b/modal/mount.py
index 48e0306e1..5b37a73c2 100644
--- a/modal/mount.py
+++ b/modal/mount.py
@@ -12,7 +12,6 @@
 from pathlib import Path, PurePosixPath
 from typing import AsyncGenerator, Callable, List, Optional, Tuple, Type, Union
 
-import aiostream
 from google.protobuf.message import Message
 
 import modal.exception
@@ -20,7 +19,7 @@
 from modal_version import __version__
 
 from ._resolver import Resolver
-from ._utils.async_utils import synchronize_api
+from ._utils.async_utils import aclosing, async_map, synchronize_api
 from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path
 from ._utils.grpc_utils import retry_transient_errors
 from ._utils.name_utils import check_object_name
@@ -499,13 +498,14 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
             raise modal.exception.MountUploadTimeoutError(f"Mounting of {file_spec.source_description} timed out")
 
-        # Create the asynchronous iterable for file specs.
-        file_specs = aiostream.stream.iterate(_Mount._get_files(self._entries))
-
         # Upload files, or check if they already exist.
         n_concurrent_uploads = 512
 
-        uploads_stream = aiostream.stream.map(file_specs, _put_file, task_limit=n_concurrent_uploads)
-        files: List[api_pb2.MountFile] = await aiostream.stream.list(uploads_stream)
+        files: List[api_pb2.MountFile] = []
+        async with aclosing(_Mount._get_files(self._entries)) as files_stream, aclosing(
+            async_map(files_stream, _put_file, concurrency=n_concurrent_uploads)
+        ) as stream:
+            async for file in stream:
+                files.append(file)
 
         if not files:
             logger.warning(f"Mount of '{message_label}' is empty.")
diff --git a/modal/network_file_system.py b/modal/network_file_system.py
index f9039a5f2..f377260cc 100644
--- a/modal/network_file_system.py
+++ b/modal/network_file_system.py
@@ -5,7 +5,6 @@
 from pathlib import Path, PurePosixPath
 from typing import Any, AsyncIterator, BinaryIO, Callable, List, Optional, Tuple, Type, Union
 
-import aiostream
 from grpclib import GRPCError, Status
 from synchronicity.async_wrap import asynccontextmanager
 
@@ -13,7 +12,7 @@
 from modal_proto import api_pb2
 
 from ._resolver import Resolver
-from ._utils.async_utils import TaskContext, synchronize_api
+from ._utils.async_utils import TaskContext, aclosing, async_map, sync_or_async_iter, synchronize_api
 from ._utils.blob_utils import LARGE_FILE_LIMIT, blob_iter, blob_upload_file
 from ._utils.grpc_utils import retry_transient_errors
 from ._utils.hash_utils import get_sha256_hex
@@ -343,12 +342,14 @@ def gen_transfers():
             relpath_str = subpath.relative_to(_local_path).as_posix()
             yield subpath, PurePosixPath(remote_path, relpath_str)
 
-        transfer_paths = aiostream.stream.iterate(gen_transfers())
-        await aiostream.stream.map(
-            transfer_paths,
-            aiostream.async_(lambda paths: self.add_local_file(paths[0], paths[1], progress_cb)),
-            task_limit=20,
-        )
+        async def _add_local_file(paths: Tuple[Path, PurePosixPath]) -> int:
+            return await self.add_local_file(paths[0], paths[1], progress_cb)
+
+        async with aclosing(sync_or_async_iter(gen_transfers())) as transfer_paths, aclosing(
+            async_map(transfer_paths, _add_local_file, concurrency=20)
+        ) as stream:
+            async for _ in stream:  # drain the stream to run the uploads
+                pass
 
     @live_method
     async def listdir(self, path: str) -> List[FileEntry]:
diff --git a/modal/parallel_map.py b/modal/parallel_map.py
index 729a97d97..e0ef84b2d 100644
--- a/modal/parallel_map.py
+++ b/modal/parallel_map.py
@@ -5,12 +5,12 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Optional, Set, Tuple
 
-from aiostream import pipe, stream
 from grpclib import GRPCError, Status
 
 from modal._utils.async_utils import (
     AsyncOrSyncIterable,
     aclosing,
+    async_map_ordered,
     async_merge,
     async_zip,
     queue_batch_iterator,
 )
@@ -111,17 +111,14 @@
     async def input_iter():
         while 1:
             raw_input = await raw_input_queue.get()
             if raw_input is None:  # end of input sentinel
-                return
+                break
             yield raw_input  # args, kwargs
 
     async def drain_input_generator():
         # Parallelize uploading blobs
-        proto_input_stream = stream.iterate(input_iter()) | pipe.map(
-            create_input,  # type: ignore[reportArgumentType]
-            ordered=True,
-            task_limit=BLOB_MAX_PARALLELISM,
-        )
-        async with proto_input_stream.stream() as streamer:
+        async with aclosing(input_iter()) as input_streamer, aclosing(
+            async_map_ordered(input_streamer, create_input, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
             async for item in streamer:
                 await input_queue.put(item)
@@ -229,14 +226,13 @@ async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> Tuple[int, Any]:
         return (item.idx, output)
 
     async def poll_outputs():
-        outputs = stream.iterate(get_all_outputs_and_clean_up())
-        outputs_fetched = outputs | pipe.map(fetch_output, ordered=True, task_limit=BLOB_MAX_PARALLELISM)  # type: ignore
-
         # map to store out-of-order outputs received
         received_outputs = {}
         output_idx = 0
 
-        async with outputs_fetched.stream() as streamer:
+        async with aclosing(get_all_outputs_and_clean_up()) as outputs, aclosing(
+            async_map_ordered(outputs, fetch_output, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
             async for idx, output in streamer:
                 count_update()
                 if not order_outputs:
diff --git a/modal/volume.py b/modal/volume.py
index 54ff99858..ddd4de73b 100644
--- a/modal/volume.py
+++ b/modal/volume.py
@@ -25,7 +25,6 @@
     Union,
 )
 
-import aiostream
 from grpclib import GRPCError, Status
 from synchronicity.async_wrap import asynccontextmanager
 
@@ -34,7 +33,7 @@
 from modal_proto import api_pb2
 
 from ._resolver import Resolver
-from ._utils.async_utils import TaskContext, asyncnullcontext, synchronize_api
+from ._utils.async_utils import TaskContext, aclosing, async_map, asyncnullcontext, synchronize_api
 from ._utils.blob_utils import (
     FileUploadSpec,
     blob_iter,
@@ -561,11 +560,14 @@ async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
             for fut in asyncio.as_completed(futs):
                 yield await fut
 
-        # Compute checksums
-        files_stream = aiostream.stream.iterate(gen_file_upload_specs())
-        # Upload files
-        uploads_stream = aiostream.stream.map(files_stream, self._upload_file, task_limit=20)
-        files: List[api_pb2.MountFile] = await aiostream.stream.list(uploads_stream)
+        # Compute checksums & upload files
+        files: List[api_pb2.MountFile] = []
+        async with aclosing(gen_file_upload_specs()) as files_stream, aclosing(
+            async_map(files_stream, self._upload_file, concurrency=20)
+        ) as stream:
+            async for item in stream:
+                files.append(item)
+
         self._progress_cb(complete=True)
 
         request = api_pb2.VolumePutFilesRequest(
diff --git a/setup.cfg b/setup.cfg
index 3b698ecac..7660ee70e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,7 +20,6 @@
 packages = find:
 python_requires = >=3.8
 install_requires =
     aiohttp
-    aiostream~=0.5.2
     certifi
     click>=8.1.0
     fastapi
diff --git a/test/async_utils_test.py b/test/async_utils_test.py
index 81b9fd165..b82f1b283 100644
--- a/test/async_utils_test.py
+++ b/test/async_utils_test.py
@@ -13,6 +13,8 @@
 from modal._utils.async_utils import (
     TaskContext,
     aclosing,
+    async_map,
+    async_map_ordered,
     async_merge,
     async_zip,
     callable_to_agen,
@@ -489,6 +491,52 @@
     assert result == [(1, 3), (2, 4)]
 
 
+@pytest.mark.asyncio
+async def test_async_zip_cancellation():
+    ev = asyncio.Event()
+
+    async def gen1():
+        await asyncio.sleep(0.1)
+        yield 1
+        await ev.wait()
+        raise asyncio.CancelledError()
+        yield 2
+
+    async def gen2():
+        yield 3
+        await asyncio.sleep(0.1)
+        yield 4
+
+    async def zip_coro():
+        async for _ in async_zip(gen1(), gen2()):
+            pass
+
+    zip_task = asyncio.create_task(zip_coro())
+    await asyncio.sleep(0.1)
+    zip_task.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await zip_task
+
+
+@pytest.mark.asyncio
+async def test_async_zip_producer_cancellation():
+    async def gen1():
+        await asyncio.sleep(0.1)
+        yield 1
+        raise asyncio.CancelledError()
+        yield 2
+
+    async def gen2():
+        yield 3
+        await asyncio.sleep(0.1)
+        yield 4
+
+    await asyncio.sleep(0.1)
+    with pytest.raises(asyncio.CancelledError):
+        async for _ in async_zip(gen1(), gen2()):
+            pass
+
+
 @pytest.mark.asyncio
 async def test_async_merge():
     result = []
@@ -564,15 +612,15 @@ async def gen2():
             await asyncio.sleep(0)
             states.append("gen2 exit")
 
-    async with aclosing(gen1()) as g1, aclosing(gen2()) as g2, aclosing(async_merge(g1, g2)) as stream:
+    async with aclosing(async_merge(gen1(), gen2())) as stream:
         async for _ in stream:
             break
 
-    assert states == [
+    assert sorted(states) == [
         "gen1 enter",
+        "gen1 exit",
         "gen2 enter",
         "gen2 exit",
-        "gen1 exit",
     ]
@@ -614,6 +662,51 @@ async def gen2():
     ]
 
 
+@pytest.mark.asyncio
+async def test_async_merge_cancellation():
+    ev = asyncio.Event()
+
+    async def gen1():
+        await asyncio.sleep(0.1)
+        yield 1
+        await ev.wait()
+        yield 2
+
+    async def gen2():
+        yield 3
+        await asyncio.sleep(0.1)
+        yield 4
+
+    async def merge_coro():
+        async for _ in async_merge(gen1(), gen2()):
+            pass
+
+    merge_task = asyncio.create_task(merge_coro())
+    await asyncio.sleep(0.1)
+    merge_task.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await merge_task
+
+
+@pytest.mark.asyncio
+async def test_async_merge_producer_cancellation():
+    async def gen1():
+        await asyncio.sleep(0.1)
+        yield 1
+        raise asyncio.CancelledError()
+        yield 2
+
+    async def gen2():
+        yield 3
+        await asyncio.sleep(0.1)
+        yield 4
+
+    await asyncio.sleep(0.1)
+    with pytest.raises(asyncio.CancelledError):
+        async for _ in async_merge(gen1(), gen2()):
+            pass
+
+
 @pytest.mark.asyncio
 async def test_callable_to_agen():
     async def foo():
@@ -624,3 +717,261 @@ async def foo():
     async for item in callable_to_agen(foo):
         result.append(item)
     assert result == [await foo()]
+
+
+@pytest.mark.asyncio
+async def test_async_map():
+    result = []
+    states = []
+
+    async def foo():
+        states.append("enter")
+        try:
+            yield 1
+            yield 2
+            yield 3
+        finally:
+            states.append("exit")
+
+    async def mapper(x):
+        await asyncio.sleep(0.1)  # Simulate some async work
+        return x * 2
+
+    async for item in async_map(foo(), mapper, concurrency=3):
+        result.append(item)
+
+    assert sorted(result) == [2, 4, 6]
+    assert states == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+async def test_async_map_input_exception_async_producer():
+    # test an exception raised inside an async producer
+    result = []
+    states = []
+
+    async def mapper_func(x):
+        await asyncio.sleep(0.1)
+        return x * 2
+
+    async def gen():
+        states.append("enter")
+        try:
+            for i in range(5):
+                if i == 3:
+                    raise SampleException("test")
+                yield i
+        finally:
+            states.append("exit")
+
+    with pytest.raises(SampleException):
+        async for item in async_map(gen(), mapper_func, concurrency=3):
+            result.append(item)
+
+    assert sorted(result) == []
+    assert sorted(states) == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+async def test_async_map_input_cancellation_async_producer():
+    # test a cancellation raised inside an async producer
+    result = []
+    states = []
+
+    async def mapper_func(x):
+        await asyncio.sleep(0.1)
+        return x * 2
+
+    async def gen():
+        states.append("enter")
+        try:
+            for i in range(5):
+                if i == 3:
+                    raise asyncio.CancelledError()
+                yield i
+        finally:
+            states.append("exit")
+
+    with pytest.raises(asyncio.CancelledError):
+        async for item in async_map(gen(), mapper_func, concurrency=3):
+            result.append(item)
+
+    assert sorted(result) == []
+    assert sorted(states) == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+async def test_async_map_cancellation_waiting_for_input():
+    # test cancelling async_map while it is waiting for input
+    result = []
+    states = []
+
+    async def mapper_func(x):
+        return x * 2
+
+    blocking_event = asyncio.Event()
+
+    async def gen():
+        states.append("enter")
+        try:
+            await blocking_event.wait()
+            yield 1
+        finally:
+            states.append("exit")
+
+    async def mapper_coro():
+        async for item in async_map(gen(), mapper_func, concurrency=3):
+            result.append(item)
+
+    mapper_task = asyncio.create_task(mapper_coro())
+    await asyncio.sleep(0.1)
+    mapper_task.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await mapper_task
+
+    assert sorted(result) == []
+    assert sorted(states) == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+async def test_async_map_input_exception_sync_producer():
+    # test an exception raised inside a sync producer
+    result = []
+    states = []
+
+    async def mapper_func(x):
+        await asyncio.sleep(0.1)
+        return x * 2
+
+    def gen():
+        states.append("enter")
+        try:
+            for i in range(5):
+                if i == 3:
+                    raise SampleException("test")
+                yield i
+        finally:
+            states.append("exit")
+
+    with pytest.raises(SampleException):
+        async for item in async_map(gen(), mapper_func, concurrency=3):
+            result.append(item)
+
+    assert sorted(result) == []
+    assert sorted(states) == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+async def test_async_map_output_exception_async_func():
+    # test an exception raised inside the async mapper function
+    result = []
+    states = []
+
+    def gen():
+        states.append("enter")
+        try:
+            for i in range(5):
+                yield i
+        finally:
+            states.append("exit")
+
+    async def mapper_func(x):
+        await asyncio.sleep(0.1)
+        if x == 3:
+            raise SampleException("test")
+        return x * 2
+
+    with pytest.raises(SampleException):
+        async for item in async_map(gen(), mapper_func, concurrency=3):
+            result.append(item)
+
+    assert sorted(result) == [0, 2, 4]
+    assert states == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+async def test_async_map_streaming_input():
+    # ensure input is streamed: items should be emitted as they are produced,
+    # not buffered up and returned only after the input is exhausted
+    result = []
+    states = []
+
+    async def gen():
+        states.append("enter")
+        try:
+            yield 1
+            await asyncio.sleep(1)
+            yield 2
+            yield 3
+        finally:
+            states.append("exit")
+
+    async def mapper(x):
+        await asyncio.sleep(0.1)
+        return x * 2
+
+    import time
+
+    start = time.time()
+    async for item in async_map(gen(), mapper, concurrency=3):
+        if item == 2:
+            assert time.time() - start < 0.5
+        else:
+            assert time.time() - start > 0.5
+        result.append(item)
+
+    assert result == [2, 4, 6]
+    assert states == ["enter", "exit"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("in_order", [True, False])
+async def test_async_map_concurrency(in_order):
+    active_mappers = 0
+    active_mappers_history = []
+
+    async def mapper(x):
+        nonlocal active_mappers
+        active_mappers += 1
+        active_mappers_history.append(active_mappers)
+        await asyncio.sleep(0.1)  # Simulate some async work
+        active_mappers -= 1
+        return x * 2
+
+    if in_order:
+        result = [item async for item in async_map_ordered(range(10), mapper, concurrency=3)]
+    else:
+        result = [item async for item in async_map(range(10), mapper, concurrency=3)]
+    assert sorted(result) == [x * 2 for x in range(10)]
+    assert max(active_mappers_history) == 3
+    # with 10 inputs and 3 workers, at least 7 of the 10 samples should
+    # observe all 3 mappers active at once
+    assert active_mappers_history.count(3) >= 7
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("in_order", [True, False])
+async def test_async_map_ordering(in_order):
+    result = []
+    ev = asyncio.Event()
+
+    async def foo():
+        yield 1
+        yield 2
+        yield 3
+
+    async def mapper(x):
+        if x == 1:
+            await ev.wait()
+
+        if x == 2:
+            ev.set()
+
+        return x * 2
+
+    if in_order:
+        async for item in async_map_ordered(foo(), mapper, concurrency=3):
+            result.append(item)
+        assert result == [2, 4, 6]
+    else:
+        async for item in async_map(foo(), mapper, concurrency=3):
+            result.append(item)
+        assert result == [4, 6, 2]
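A minimal usage sketch of the two new helpers (illustrative only, not part of the diff; like the tests above, it imports the non-synchronized helpers directly from modal._utils.async_utils and drives them from plain asyncio):

    import asyncio

    from modal._utils.async_utils import async_map, async_map_ordered

    async def demo():
        async def double(x: int) -> int:
            await asyncio.sleep(0.01)  # simulate async work
            return x * 2

        # async_map yields results in completion order, running up to
        # 3 mapper invocations concurrently.
        unordered = [y async for y in async_map(range(5), double, concurrency=3)]

        # async_map_ordered uses the same machinery but buffers out-of-order
        # results and re-sequences them to match the input order.
        ordered = [y async for y in async_map_ordered(range(5), double, concurrency=3)]

        assert sorted(unordered) == ordered == [0, 2, 4, 6, 8]

    asyncio.run(demo())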