From 55349c452e1e688435ef6d2b72e949244aca572b Mon Sep 17 00:00:00 2001
From: Rohan Singh
Date: Mon, 25 Nov 2024 22:56:43 +0000
Subject: [PATCH] Track in-flight inputs by idx instead of input_id

---
 modal/parallel_map.py | 62 ++++++++++++++++++++-----------------------
 1 file changed, 29 insertions(+), 33 deletions(-)

diff --git a/modal/parallel_map.py b/modal/parallel_map.py
index 15a3570e2..dbc4b8946 100644
--- a/modal/parallel_map.py
+++ b/modal/parallel_map.py
@@ -108,13 +108,11 @@ async def _map_invocation(
     num_inputs = 0
     num_outputs = 0
 
-    io_lock = asyncio.Lock()
-
     def count_update():
         if count_update_callback is not None:
             count_update_callback(num_outputs, num_inputs)
 
-    pending_outputs: dict[str, _MapRetryContext | None] = {}  # Map input_id -> retry context
+    pending_outputs: dict[int, _MapRetryContext | None] = {}  # Map input idx -> retry context
     completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)
 
     input_queue: TimedPriorityQueue[
@@ -183,7 +181,7 @@ async def pump_inputs(items):
             count_update()
             for response_item in resp.inputs:
                 original_item = items_by_idx[response_item.idx]
-                pending_outputs[response_item.input_id] = _MapRetryContext(
+                pending_outputs[response_item.idx] = _MapRetryContext(
                     input=original_item.input,
                     input_id=response_item.input_id,
                     input_jwt=response_item.input_jwt,
@@ -241,8 +239,7 @@ async def process_input_queue():
                     raise ValueError(f"Unknown item type on queue: {item}")
 
             if put_items:
-                async with io_lock:
-                    await pump_inputs(put_items)
+                await pump_inputs(put_items)
 
             if retry_items:
                 await pump_retries(retry_items)
@@ -272,35 +269,34 @@ async def get_all_outputs():
                 if len(response.outputs) == 0:
                     continue
 
-                async with io_lock:
-                    last_entry_id = response.last_entry_id
-                    for item in response.outputs:
-                        if item.input_id in completed_outputs:
-                            # If this input is already completed, it means the output has already been
-                            # processed and was received again due to a duplicate.
-                            continue
-
-                        if item.result and item.result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
-                            # If the output is not successful, we need to retry it.
-                            retry_context = pending_outputs[item.input_id]
-                            if retry_context:
-                                delay_ms = retry_context.retry_manager.get_delay_ms()
-
-                                if delay_ms is not None:
-                                    retry = api_pb2.FunctionRetryInputsItem(
-                                        input_jwt=retry_context.input_jwt,
-                                        input=retry_context.input,
-                                    )
-                                    await input_queue.put_with_timestamp(time.time() + delay_ms, retry)
-                                    continue
-
-                        completed_outputs.add(item.input_id)
-                        num_outputs += 1
+                last_entry_id = response.last_entry_id
+                for item in response.outputs:
+                    if item.input_id in completed_outputs:
+                        # If this input is already completed, it means the output has already been
+                        # processed and was received again due to a duplicate.
+                        continue
 
+                    if item.result and item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
                         # clear the retry context to allow it to be garbage collected
-                        pending_outputs[item.input_id] = None
-
-                        yield item
+                        pending_outputs[item.idx] = None
+                    else:
+                        # If the output is not successful, we need to retry it.
+                        retry_context = pending_outputs[item.idx]
+                        if retry_context:
+                            delay_ms = retry_context.retry_manager.get_delay_ms()
+
+                            if delay_ms is not None:
+                                retry = api_pb2.FunctionRetryInputsItem(
+                                    input_jwt=retry_context.input_jwt,
+                                    input=retry_context.input,
+                                )
+                                await input_queue.put_with_timestamp(time.time() + delay_ms, retry)
+                                continue
+                    completed_outputs.add(item.input_id)
+                    num_outputs += 1
+
+                    yield item
 
     async def get_all_outputs_and_clean_up():
         assert client.stub
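Note on the change: the retry bookkeeping is now keyed by the client-assigned `idx`, which stays stable for the lifetime of the map call, instead of the server-assigned `input_id`, which (per the logic of this patch) may differ between attempts of the same logical input; the same commit drops the `io_lock` that previously serialized access to the map. Below is a minimal, self-contained sketch of the pattern, with simplified stand-ins: `RetryManager`, `RetryContext`, `handle_output`, and the queue shape are illustrative, not Modal's actual types.

```python
import asyncio
import time
from dataclasses import dataclass, field


@dataclass
class RetryManager:
    """Simplified stand-in: exponential backoff with a capped attempt count."""

    max_retries: int = 3
    base_delay_ms: float = 100.0
    attempts: int = 0

    def get_delay_ms(self) -> float | None:
        self.attempts += 1
        if self.attempts > self.max_retries:
            return None  # retries exhausted; the failure becomes final
        return self.base_delay_ms * 2 ** (self.attempts - 1)


@dataclass
class RetryContext:
    input: bytes
    input_id: str  # server-assigned; a re-enqueued input may get a new one
    input_jwt: str
    retry_manager: RetryManager = field(default_factory=RetryManager)


# Keyed by the client-assigned idx, which is stable across retries.
pending_outputs: dict[int, RetryContext | None] = {}


async def handle_output(
    idx: int,
    success: bool,
    retry_queue: "asyncio.PriorityQueue[tuple[float, int]]",
) -> bool:
    """Return True if the output is final, False if the input was re-enqueued."""
    if success:
        # Clear the context so it can be garbage collected.
        pending_outputs[idx] = None
        return True
    ctx = pending_outputs.get(idx)
    if ctx is not None:
        delay_ms = ctx.retry_manager.get_delay_ms()
        if delay_ms is not None:
            # Schedule the retry at (now + backoff). The same idx keeps
            # pointing at the same context no matter how many attempts occur.
            await retry_queue.put((time.time() + delay_ms / 1000, idx))
            return False
    return True  # out of retries: surface the failed result to the caller
```

The invariant the sketch illustrates is the one the patch relies on: the key used when the input is sent must be the same key available on every output for that input, which holds for `idx` but not for an identifier that can be reassigned when the input is retried.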