Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have client retry lost inputs #2600

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from .._serialization import deserialize, deserialize_data_format, serialize
from .._traceback import append_modal_tb
from ..config import config, logger
from ..exception import DeserializationError, ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
from ..exception import (
DeserializationError,
ExecutionError,
FunctionTimeoutError,
InternalFailure,
InvalidError,
RemoteError,
)
from ..mount import ROOT_DIR, _is_modal_path, _Mount
from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES
Expand Down Expand Up @@ -463,6 +470,8 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,

if result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT:
raise FunctionTimeoutError(result.exception)
elif result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
raise InternalFailure(result.exception)
elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
if data:
try:
Expand Down
6 changes: 6 additions & 0 deletions modal/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ class ServerWarning(UserWarning):
"""Warning originating from the Modal server and re-issued in client code."""


class InternalFailure(Error):
"""
Retriable internal error.
"""


class _CliUserExecutionError(Exception):
"""mdmd:hidden
Private wrapper for exceptions during when importing or running stubs from the CLI.
Expand Down
25 changes: 20 additions & 5 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@
from .client import _Client
from .cloud_bucket_mount import _CloudBucketMount, cloud_bucket_mounts_to_proto
from .config import config
from .exception import ExecutionError, FunctionTimeoutError, InvalidError, NotFoundError, OutputExpiredError
from .exception import (
ExecutionError,
FunctionTimeoutError,
InternalFailure,
InvalidError,
NotFoundError,
OutputExpiredError,
)
from .gpu import GPU_T, parse_gpu_config
from .image import _Image
from .mount import _get_client_mount, _Mount, get_auto_mounts
Expand Down Expand Up @@ -174,7 +181,7 @@ async def create(
return _Invocation(client.stub, function_call_id, client, retry_context)

async def pop_function_call_outputs(
self, timeout: Optional[float], clear_on_success: bool
self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None
) -> api_pb2.FunctionGetOutputsResponse:
t0 = time.time()
if timeout is None:
Expand All @@ -190,6 +197,7 @@ async def pop_function_call_outputs(
last_entry_id="0-0",
clear_on_success=clear_on_success,
requested_at=time.time(),
input_jwts=input_jwts,
)
response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
self.stub.FunctionGetOutputs,
Expand Down Expand Up @@ -219,10 +227,14 @@ async def _retry_input(self) -> None:
request,
)

async def _get_single_output(self) -> Any:
async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any:
# waits indefinitely for a single result for the function, and clear the outputs buffer after
item: api_pb2.FunctionGetOutputsItem = (
await self.pop_function_call_outputs(timeout=None, clear_on_success=True)
await self.pop_function_call_outputs(
timeout=None,
clear_on_success=True,
input_jwts=[expected_jwt] if expected_jwt else None,
)
).outputs[0]
return await _process_result(item.result, item.data_format, self.stub, self.client)

Expand All @@ -242,9 +254,12 @@ async def run_function(self) -> Any:

while True:
try:
return await self._get_single_output()
return await self._get_single_output(ctx.input_jwt)
except (UserCodeException, FunctionTimeoutError) as exc:
await user_retry_manager.raise_or_sleep(exc)
except InternalFailure:
# For system failures on the server, we retry immediately.
pass
await self._retry_input()

async def poll_function(self, timeout: Optional[float] = None):
Expand Down
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(self, blob_host, blobs, credentials):
self.done = False
self.rate_limit_sleep_duration = None
self.fail_get_inputs = False
self.failure_status = api_pb2.GenericResult.GENERIC_STATUS_FAILURE
self.slow_put_inputs = False
self.container_inputs = []
self.container_outputs = []
Expand Down Expand Up @@ -1118,7 +1119,7 @@ async def FunctionGetOutputs(self, stream):
except Exception as exc:
serialized_exc = cloudpickle.dumps(exc)
result = api_pb2.GenericResult(
status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
status=self.failure_status,
data=serialized_exc,
exception=repr(exc),
traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
Expand Down
14 changes: 14 additions & 0 deletions test/function_retry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def test_all_retries_fail_raises_error(client, setup_app_and_function, monkeypat
app, f = setup_app_and_function
with app.run(client=client):
with pytest.raises(FunctionCallCountException) as exc_info:
# The client should give up after the 4th call.
f.remote(5)
# Assert the function was called 4 times - the original call plus 3 retries
assert exc_info.value.function_call_count == 4


Expand Down Expand Up @@ -85,3 +87,15 @@ def test_retry_dealy_ms():

retry_policy = api_pb2.FunctionRetryPolicy(retries=2, backoff_coefficient=3, initial_delay_ms=2000)
assert RetryManager._retry_delay_ms(2, retry_policy) == 6000


def test_lost_inputs_retried(client, setup_app_and_function, monkeypatch, servicer):
monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true")
app, f = setup_app_and_function
# The client should retry if it receives a internal failure status.
servicer.failure_status = api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE

with app.run(client=client):
f.remote(10)
# Assert the function was called 10 times
assert function_call_count == 10
Loading