Skip to content

Commit

Permalink
Add client retry support to .map
Browse files Browse the repository at this point in the history
  • Loading branch information
rohansingh committed Nov 25, 2024
1 parent 315be36 commit 00a2ac2
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 66 deletions.
69 changes: 69 additions & 0 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,72 @@ async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
logger.exception(f"Error closing async generator: {e}")
if first_exception is not None:
raise first_exception


class TimedPriorityQueue(asyncio.PriorityQueue[tuple[float, Union[T, None]]]):
"""
A priority queue that schedules items to be processed at specific timestamps.
"""

def __init__(self, maxsize: int = 0):
super().__init__(maxsize=maxsize)
self.condition = asyncio.Condition()

async def put_with_timestamp(self, timestamp: float, item: Union[T, None]):
"""
Add an item to the queue to be processed at a specific timestamp.
"""
async with self.condition:
await super().put((timestamp, item))
self.condition.notify_all() # notify any waiting coroutines

async def get_next(self) -> Union[T, None]:
"""
Get the next item from the queue that is ready to be processed.
"""
while True:
async with self.condition:
while self.empty():
await self.condition.wait()

# peek at the next item
timestamp, item = await super().get()
now = time.time()

if timestamp > now:
# not ready yet, calculate sleep time
sleep_time = timestamp - now
self.put_nowait((timestamp, item)) # put it back

# wait until either the timeout or a new item is added
try:
await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
except asyncio.TimeoutError:
continue
else:
return item

async def batch(self, max_batch_size=100, debounce_time=0.015) -> AsyncGenerator[list[T], None]:
"""
Read from the queue but return lists of items when queue is large.
Treats a None value as the end of queue items.
"""
batch: list[T] = []
while True:
try:
item: Union[T, None] = await asyncio.wait_for(self.get_next(), timeout=debounce_time)

if item is None:
if batch:
yield batch
return
batch.append(item)

if len(batch) >= max_batch_size:
yield batch
batch = []
except asyncio.TimeoutError:
if batch:
yield batch
batch = []
6 changes: 5 additions & 1 deletion modal/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright Modal Labs 2023
import asyncio
import dataclasses
import inspect
import textwrap
Expand Down Expand Up @@ -256,7 +257,10 @@ async def run_function(self) -> Any:
try:
return await self._get_single_output()
except (UserCodeException, FunctionTimeoutError) as exc:
await user_retry_manager.raise_or_sleep(exc)
delay_ms = user_retry_manager.get_delay_ms()
if delay_ms is None:
raise exc
await asyncio.sleep(delay_ms / 1000)
await self._retry_input()

async def poll_function(self, timeout: Optional[float] = None):
Expand Down
Loading

0 comments on commit 00a2ac2

Please sign in to comment.