Skip to content

Commit

Permalink
Track in-flight inputs by idx instead of input_id
Browse files Browse the repository at this point in the history
  • Loading branch information
rohansingh committed Dec 2, 2024
1 parent dc920aa commit 55349c4
Showing 1 changed file with 29 additions and 33 deletions.
62 changes: 29 additions & 33 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,11 @@ async def _map_invocation(
num_inputs = 0
num_outputs = 0

io_lock = asyncio.Lock()

def count_update():
if count_update_callback is not None:
count_update_callback(num_outputs, num_inputs)

pending_outputs: dict[str, _MapRetryContext | None] = {} # Map input_id -> retry context
pending_outputs: dict[int, _MapRetryContext | None] = {} # Map input idx -> retry context
completed_outputs: set[str] = set() # Set of input_ids whose outputs are complete (expecting no more values)

input_queue: TimedPriorityQueue[
Expand Down Expand Up @@ -183,7 +181,7 @@ async def pump_inputs(items):
count_update()
for response_item in resp.inputs:
original_item = items_by_idx[response_item.idx]
pending_outputs[response_item.input_id] = _MapRetryContext(
pending_outputs[response_item.idx] = _MapRetryContext(
input=original_item.input,
input_id=response_item.input_id,
input_jwt=response_item.input_jwt,
Expand Down Expand Up @@ -241,8 +239,7 @@ async def process_input_queue():
raise ValueError(f"Unknown item type on queue: {item}")

if put_items:
async with io_lock:
await pump_inputs(put_items)
await pump_inputs(put_items)

if retry_items:
await pump_retries(retry_items)
Expand Down Expand Up @@ -272,35 +269,34 @@ async def get_all_outputs():
if len(response.outputs) == 0:
continue

async with io_lock:
last_entry_id = response.last_entry_id
for item in response.outputs:
if item.input_id in completed_outputs:
# If this input is already completed, it means the output has already been
# processed and was received again due to a duplicate.
continue

if item.result and item.result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
# If the output is not successful, we need to retry it.
retry_context = pending_outputs[item.input_id]
if retry_context:
delay_ms = retry_context.retry_manager.get_delay_ms()

if delay_ms is not None:
retry = api_pb2.FunctionRetryInputsItem(
input_jwt=retry_context.input_jwt,
input=retry_context.input,
)
await input_queue.put_with_timestamp(time.time() + delay_ms, retry)
continue

completed_outputs.add(item.input_id)
num_outputs += 1
last_entry_id = response.last_entry_id
for item in response.outputs:
if item.input_id in completed_outputs:
# If this input is already completed, it means the output has already been
# processed and was received again due to a duplicate.
continue

if item.result and item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
# clear the retry context to allow it to be garbage collected
pending_outputs[item.input_id] = None

yield item
pending_outputs[item.idx] = None
else:
# If the output is not successful, we need to retry it.
retry_context = pending_outputs[item.idx]
if retry_context:
delay_ms = retry_context.retry_manager.get_delay_ms()

if delay_ms is not None:
retry = api_pb2.FunctionRetryInputsItem(
input_jwt=retry_context.input_jwt,
input=retry_context.input,
)
await input_queue.put_with_timestamp(time.time() + delay_ms, retry)
continue

completed_outputs.add(item.input_id)
num_outputs += 1

yield item

async def get_all_outputs_and_clean_up():
assert client.stub
Expand Down

0 comments on commit 55349c4

Please sign in to comment.