Skip to content

Commit

Permalink
Add client retries to client
Browse files Browse the repository at this point in the history
  • Loading branch information
rculbertson committed Dec 5, 2024
1 parent 38dc9e3 commit 081faab
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 5 deletions.
8 changes: 8 additions & 0 deletions modal/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ class PendingDeprecationError(UserWarning):
"""Soon to be deprecated feature. Only used intermittently because of multi-repo concerns."""


class LostInputsError(Error):
    """Raised when the server reports that it is no longer processing specified inputs.

    Attributes:
        lost_inputs: Ids of the inputs the server reported as lost.
    """

    def __init__(self, lost_inputs: list[str]):
        # Copy into a plain list so the exception does not hold on to whatever
        # sequence object the caller handed in (e.g. a field of a response message).
        self.lost_inputs = list(lost_inputs)
        # Pass a message to the base class: a bare super().__init__() leaves
        # str(exc) empty, which hides the affected ids in logs and tracebacks.
        super().__init__(f"Lost inputs: {self.lost_inputs}")


class _CliUserExecutionError(Exception):
"""mdmd:hidden
Private wrapper for exceptions raised while importing or running stubs from the CLI.
Expand Down
17 changes: 12 additions & 5 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ExecutionError,
FunctionTimeoutError,
InvalidError,
LostInputsError,
NotFoundError,
OutputExpiredError,
deprecation_warning,
Expand Down Expand Up @@ -179,7 +180,7 @@ async def create(
return _Invocation(client.stub, function_call_id, client, retry_context)

async def pop_function_call_outputs(
self, timeout: Optional[float], clear_on_success: bool
self, timeout: Optional[float], clear_on_success: bool, expected_input_ids: Optional[list[str]] = None
) -> api_pb2.FunctionGetOutputsResponse:
t0 = time.time()
if timeout is None:
Expand All @@ -195,13 +196,17 @@ async def pop_function_call_outputs(
last_entry_id="0-0",
clear_on_success=clear_on_success,
requested_at=time.time(),
expected_input_ids=expected_input_ids,
)
response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
self.stub.FunctionGetOutputs,
request,
attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
)

if response.lost_input_ids:
raise LostInputsError(response.lost_input_ids)

if len(response.outputs) > 0:
return response

Expand All @@ -224,10 +229,12 @@ async def _retry_input(self) -> None:
request,
)

async def _get_single_output(self, expected_input_id: Optional[str] = None) -> Any:
    """Wait indefinitely for a single result of the function call and return it.

    The outputs buffer is cleared once the output has been fetched successfully.

    Args:
        expected_input_id: Id of the input whose output is awaited. When given, it
            is forwarded to the server as the only expected input id; when omitted
            (or empty), no expected-input filter is sent.

    Returns:
        The processed result of the first output item in the response.

    Raises:
        LostInputsError: Propagated from ``pop_function_call_outputs`` when the
            server reports lost inputs.
    """
    # Only build the list when there is an id to put in it. Wrapping the default
    # of None unconditionally would send [None], which is not a list of input ids.
    expected_input_ids = [expected_input_id] if expected_input_id else None
    response = await self.pop_function_call_outputs(
        timeout=None, clear_on_success=True, expected_input_ids=expected_input_ids
    )
    item: api_pb2.FunctionGetOutputsItem = response.outputs[0]
    return await _process_result(item.result, item.data_format, self.stub, self.client)

Expand All @@ -247,8 +254,8 @@ async def run_function(self) -> Any:

while True:
try:
return await self._get_single_output()
except (UserCodeException, FunctionTimeoutError) as exc:
return await self._get_single_output(self._retry_context.input_id)
except (UserCodeException, FunctionTimeoutError, LostInputsError) as exc:
await user_retry_manager.raise_or_sleep(exc)
await self._retry_input()

Expand Down
9 changes: 9 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(self, blob_host, blobs, credentials):
self.done = False
self.rate_limit_sleep_duration = None
self.fail_get_inputs = False
self.fail_get_outputs_with_lost_inputs = False
self.slow_put_inputs = False
self.container_inputs = []
self.container_outputs = []
Expand Down Expand Up @@ -1090,6 +1091,14 @@ async def FunctionGetOutputs(self, stream):
input_id=input_id, idx=idx, result=result, data_format=api_pb2.DATA_FORMAT_PICKLE
)

if self.fail_get_outputs_with_lost_inputs:
# We fail the output only after invoking the input's function, because our tests use the number of function
# invocations to assert that the function was retried the correct number of times.
await stream.send_message(
api_pb2.FunctionGetOutputsResponse(num_unfinished_inputs=1, lost_input_ids=[input_id])
)
return

if output_exc:
output = output_exc
else:
Expand Down
17 changes: 17 additions & 0 deletions test/function_retry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import modal
from modal import App
from modal.exception import LostInputsError
from modal.retries import RetryManager
from modal_proto import api_pb2

Expand Down Expand Up @@ -56,7 +57,9 @@ def test_all_retries_fail_raises_error(client, setup_app_and_function, monkeypat
app, f = setup_app_and_function
with app.run(client=client):
with pytest.raises(FunctionCallCountException) as exc_info:
# The client should give up after the 4th call.
f.remote(5)
# Assert the function was called 4 times - the original call plus 3 retries
assert exc_info.value.function_call_count == 4


Expand Down Expand Up @@ -85,3 +88,17 @@ def test_retry_dealy_ms():

retry_policy = api_pb2.FunctionRetryPolicy(retries=2, backoff_coefficient=3, initial_delay_ms=2000)
assert RetryManager._retry_delay_ms(2, retry_policy) == 6000


def test_lost_inputs_retried(client, setup_app_and_function, monkeypatch, servicer):
# Checks that a call whose outputs are always reported as lost is retried and finally raises LostInputsError.
monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true")
app, f = setup_app_and_function
# This flag forces the fake server to always report 1 lost input and no successful outputs.
servicer.fail_get_outputs_with_lost_inputs = True
with app.run(client=client):
with pytest.raises(LostInputsError):
# The value we pass to the function doesn't matter: the call to GetOutputs will always fail due to lost
# inputs. We use the function only as a way to track how many times it was invoked.
f.remote(5)
# Assert the function was called 4 times - the original call plus 3 retries.
# NOTE(review): function_call_count is presumably a module-level counter defined elsewhere in this file - confirm.
assert function_call_count == 4

0 comments on commit 081faab

Please sign in to comment.