Skip to content

Commit

Permalink
Fix race condition between inputs/outputs
Browse files Browse the repository at this point in the history
Though very unlikely outside of unit tests, it's possible to have an output
returned before the corresponding retry context has been put into the
`pending_outputs` dict.
  • Loading branch information
rohansingh committed Dec 3, 2024
1 parent c585288 commit 775ad63
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def count_update():
if count_update_callback is not None:
count_update_callback(num_outputs, num_inputs)

pending_outputs: dict[int, _MapItemRetryContext | None] = {} # Map input idx -> retry context
pending_outputs: dict[int, asyncio.Future[_MapItemRetryContext]] = {} # Map input idx -> retry context
completed_outputs: set[str] = set() # Set of input_ids whose outputs are complete (expecting no more values)

input_queue: TimedPriorityQueue[
Expand Down Expand Up @@ -161,7 +161,11 @@ async def pump_inputs(items):
f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
)

event_loop = asyncio.get_event_loop()
items_by_idx = {item.idx: item for item in items}
for item in items:
pending_outputs[item.idx] = event_loop.create_future()

while True:
try:
resp = await retry_transient_errors(
Expand All @@ -185,14 +189,15 @@ async def pump_inputs(items):
for response_item in resp.inputs:
original_item = items_by_idx[response_item.idx]

if response_item.idx not in pending_outputs:
pending_outputs[response_item.idx] = _MapItemRetryContext(
pending_outputs[response_item.idx].set_result(
_MapItemRetryContext(
function_call_invocation_type=function_call_invocation_type,
input=original_item.input,
input_id=response_item.input_id,
input_jwt=response_item.input_jwt,
retry_manager=RetryManager(retry_policy),
)
)
logger.debug(
f"Successfully pushed {len(items)} inputs to server. "
f"Num queued inputs awaiting push is {input_queue.qsize()}."
Expand Down Expand Up @@ -282,12 +287,13 @@ async def get_all_outputs():
# processed and was received again due to a duplicate.
continue

retry_context = await pending_outputs[item.idx]

if item.result and item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
# clear the retry context to allow it to be garbage collected
pending_outputs[item.idx] = None
del pending_outputs[item.idx]
else:
# If the output is not successful, we need to retry it.
retry_context = pending_outputs[item.idx]
if (
retry_context
and retry_context.function_call_invocation_type == api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
Expand Down

0 comments on commit 775ad63

Please sign in to comment.