Skip to content

Commit

Permalink
Fix TimedPriorityQueue to work with duplicate timestamps
Browse files Browse the repository at this point in the history
When two items had the same timestamp, we would try to sort by the
actual item value, which breaks for types that don't support comparison.

Instead use a nonce when inserting an item, to ensure that we never have
to compare the item value itself.
  • Loading branch information
rohansingh committed Dec 2, 2024
1 parent ae25593 commit bd1ee57
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
10 changes: 6 additions & 4 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,21 +762,23 @@ async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
raise first_exception


class TimedPriorityQueue(asyncio.PriorityQueue[tuple[float, Union[T, None]]]):
class TimedPriorityQueue(asyncio.PriorityQueue[tuple[float, int, Union[T, None]]]):
"""
A priority queue that schedules items to be processed at specific timestamps.
"""

def __init__(self, maxsize: int = 0):
super().__init__(maxsize=maxsize)
self.condition = asyncio.Condition()
self.nonce = 0

async def put_with_timestamp(self, timestamp: float, item: Union[T, None]):
"""
Add an item to the queue to be processed at a specific timestamp.
"""
async with self.condition:
await super().put((timestamp, item))
self.nonce += 1
await super().put((timestamp, self.nonce, item))
self.condition.notify_all() # notify any waiting coroutines

async def get_next(self) -> Union[T, None]:
Expand All @@ -789,13 +791,13 @@ async def get_next(self) -> Union[T, None]:
await self.condition.wait()

# peek at the next item
timestamp, item = await super().get()
timestamp, nonce, item = await super().get()
now = time.time()

if timestamp > now:
# not ready yet, calculate sleep time
sleep_time = timestamp - now
self.put_nowait((timestamp, item)) # put it back
self.put_nowait((timestamp, nonce, item)) # put it back

# wait until either the timeout or a new item is added
try:
Expand Down
32 changes: 29 additions & 3 deletions test/async_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,11 +1312,12 @@ def line():
@pytest.mark.asyncio
async def test_timed_priority_queue():
queue: async_utils.TimedPriorityQueue[str] = async_utils.TimedPriorityQueue()
now = time.time()

async def producer():
await queue.put_with_timestamp(time.time() + 0.2, "item2")
await queue.put_with_timestamp(time.time() + 0.1, "item1")
await queue.put_with_timestamp(time.time() + 0.3, "item3")
await queue.put_with_timestamp(now + 0.2, "item2")
await queue.put_with_timestamp(now + 0.1, "item1")
await queue.put_with_timestamp(now + 0.3, "item3")

async def consumer():
items = []
Expand All @@ -1329,3 +1330,28 @@ async def consumer():
items = await consumer()

assert items == ["item1", "item2", "item3"]


@pytest.mark.asyncio
async def test_timed_priority_queue_duplicates():
class _QueueItem:
pass

queue: async_utils.TimedPriorityQueue[_QueueItem] = async_utils.TimedPriorityQueue()
now = time.time()

async def producer():
await queue.put_with_timestamp(now + 0.1, _QueueItem())
await queue.put_with_timestamp(now + 0.1, _QueueItem())

async def consumer():
items = []
for _ in range(2):
item = await queue.get_next()
items.append(item)
return items

await producer()
items = await consumer()

assert len([it for it in items if isinstance(it, _QueueItem)]) == 2

0 comments on commit bd1ee57

Please sign in to comment.