From a5c77aa08bef120e6a16fca84b74afe0c738c723 Mon Sep 17 00:00:00 2001 From: Ryan Culbertson Date: Tue, 26 Nov 2024 18:34:17 +0000 Subject: [PATCH 1/3] Have client retry lost inputs --- modal/_utils/function_utils.py | 11 ++++++++++- modal/exception.py | 8 ++++++++ modal/functions.py | 25 ++++++++++++++++++++----- test/conftest.py | 3 ++- test/function_retry_test.py | 14 ++++++++++++++ 5 files changed, 54 insertions(+), 7 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 7e60afb1cb..d6b9aacc9b 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -17,7 +17,14 @@ from .._serialization import deserialize, deserialize_data_format, serialize from .._traceback import append_modal_tb from ..config import config, logger -from ..exception import DeserializationError, ExecutionError, FunctionTimeoutError, InvalidError, RemoteError +from ..exception import ( + DeserializationError, + ExecutionError, + FunctionTimeoutError, + InternalFailure, + InvalidError, + RemoteError, +) from ..mount import ROOT_DIR, _is_modal_path, _Mount from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES @@ -463,6 +470,8 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub, if result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT: raise FunctionTimeoutError(result.exception) + elif result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE: + raise InternalFailure(result.exception) elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS: if data: try: diff --git a/modal/exception.py b/modal/exception.py index 975c49ecab..b2b97dfcc9 100644 --- a/modal/exception.py +++ b/modal/exception.py @@ -108,6 +108,14 @@ class ServerWarning(UserWarning): """Warning originating from the Modal server and re-issued in client code.""" +class InternalFailure(Error): + """ + Raised when the server returns 
GENERIC_STATUS_INTERNAL_FAILURE. This is a retiable error which can be + caused by events like 1) redis crashing and the server losing track of inputs, or 2) a worker being + preempted, which terminates the input. + """ + + class _CliUserExecutionError(Exception): """mdmd:hidden Private wrapper for exceptions during when importing or running stubs from the CLI. diff --git a/modal/functions.py b/modal/functions.py index 76a369766a..a814ef35e1 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -59,7 +59,14 @@ from .client import _Client from .cloud_bucket_mount import _CloudBucketMount, cloud_bucket_mounts_to_proto from .config import config -from .exception import ExecutionError, FunctionTimeoutError, InvalidError, NotFoundError, OutputExpiredError +from .exception import ( + ExecutionError, + FunctionTimeoutError, + InternalFailure, + InvalidError, + NotFoundError, + OutputExpiredError, +) from .gpu import GPU_T, parse_gpu_config from .image import _Image from .mount import _get_client_mount, _Mount, get_auto_mounts @@ -174,7 +181,7 @@ async def create( return _Invocation(client.stub, function_call_id, client, retry_context) async def pop_function_call_outputs( - self, timeout: Optional[float], clear_on_success: bool + self, timeout: Optional[float], clear_on_success: bool, expected_jwts: Optional[list[str]] = None ) -> api_pb2.FunctionGetOutputsResponse: t0 = time.time() if timeout is None: @@ -190,6 +197,7 @@ async def pop_function_call_outputs( last_entry_id="0-0", clear_on_success=clear_on_success, requested_at=time.time(), + expected_jwts=expected_jwts, ) response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors( self.stub.FunctionGetOutputs, @@ -219,10 +227,14 @@ async def _retry_input(self) -> None: request, ) - async def _get_single_output(self) -> Any: + async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any: # waits indefinitely for a single result for the function, and clear the outputs buffer after item: 
api_pb2.FunctionGetOutputsItem = ( - await self.pop_function_call_outputs(timeout=None, clear_on_success=True) + await self.pop_function_call_outputs( + timeout=None, + clear_on_success=True, + expected_jwts=[expected_jwt] if expected_jwt else None, + ) ).outputs[0] return await _process_result(item.result, item.data_format, self.stub, self.client) @@ -242,9 +254,12 @@ async def run_function(self) -> Any: while True: try: - return await self._get_single_output() + return await self._get_single_output(ctx.input_jwt) except (UserCodeException, FunctionTimeoutError) as exc: await user_retry_manager.raise_or_sleep(exc) + except InternalFailure: + # For system failures on the server, we retry immediately. + pass await self._retry_input() async def poll_function(self, timeout: Optional[float] = None): diff --git a/test/conftest.py b/test/conftest.py index b2e98581ea..b9870aae70 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -154,6 +154,7 @@ def __init__(self, blob_host, blobs, credentials): self.done = False self.rate_limit_sleep_duration = None self.fail_get_inputs = False + self.failure_status = api_pb2.GenericResult.GENERIC_STATUS_FAILURE self.slow_put_inputs = False self.container_inputs = [] self.container_outputs = [] @@ -1118,7 +1119,7 @@ async def FunctionGetOutputs(self, stream): except Exception as exc: serialized_exc = cloudpickle.dumps(exc) result = api_pb2.GenericResult( - status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, + status=self.failure_status, data=serialized_exc, exception=repr(exc), traceback="".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), diff --git a/test/function_retry_test.py b/test/function_retry_test.py index e848751eba..3df909611a 100644 --- a/test/function_retry_test.py +++ b/test/function_retry_test.py @@ -56,7 +56,9 @@ def test_all_retries_fail_raises_error(client, setup_app_and_function, monkeypat app, f = setup_app_and_function with app.run(client=client): with 
pytest.raises(FunctionCallCountException) as exc_info: + # The client should give up after the 4th call. f.remote(5) + # Assert the function was called 4 times - the original call plus 3 retries assert exc_info.value.function_call_count == 4 @@ -85,3 +87,15 @@ def test_retry_dealy_ms(): retry_policy = api_pb2.FunctionRetryPolicy(retries=2, backoff_coefficient=3, initial_delay_ms=2000) assert RetryManager._retry_delay_ms(2, retry_policy) == 6000 + + +def test_lost_inputs_retried(client, setup_app_and_function, monkeypatch, servicer): + monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true") + app, f = setup_app_and_function + # The client should retry if it receives an internal failure status. + servicer.failure_status = api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE + + with app.run(client=client): + f.remote(10) + # Assert the function was called 10 times + assert function_call_count == 10 From 79d9e96197459a1f762fac2526e7a58d0cf377b8 Mon Sep 17 00:00:00 2001 From: Ryan Culbertson Date: Thu, 2 Jan 2025 11:34:55 -0500 Subject: [PATCH 2/3] Update modal/exception.py Co-authored-by: Richard Gong --- modal/exception.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modal/exception.py b/modal/exception.py index b2b97dfcc9..aadc571a9f 100644 --- a/modal/exception.py +++ b/modal/exception.py @@ -110,9 +110,7 @@ class ServerWarning(UserWarning): class InternalFailure(Error): """ - Raised when the server returns GENERIC_STATUS_INTERNAL_FAILURE. This is a retiable error which can be - caused by events like 1) redis crashing and the server losing track of inputs, or 2) a worker being - preempted, which terminates the input. + Retriable internal error.
""" From 07dad3a1813ed64d6c36f1ee916eece2758121c6 Mon Sep 17 00:00:00 2001 From: Ryan Culbertson Date: Thu, 2 Jan 2025 16:38:05 +0000 Subject: [PATCH 3/3] Rename expected_jwts to input_jwts --- modal/functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index a814ef35e1..d145c0014d 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -181,7 +181,7 @@ async def create( return _Invocation(client.stub, function_call_id, client, retry_context) async def pop_function_call_outputs( - self, timeout: Optional[float], clear_on_success: bool, expected_jwts: Optional[list[str]] = None + self, timeout: Optional[float], clear_on_success: bool, input_jwts: Optional[list[str]] = None ) -> api_pb2.FunctionGetOutputsResponse: t0 = time.time() if timeout is None: @@ -197,7 +197,7 @@ async def pop_function_call_outputs( last_entry_id="0-0", clear_on_success=clear_on_success, requested_at=time.time(), - expected_jwts=expected_jwts, + input_jwts=input_jwts, ) response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors( self.stub.FunctionGetOutputs, @@ -233,7 +233,7 @@ async def _get_single_output(self, expected_jwt: Optional[str] = None) -> Any: await self.pop_function_call_outputs( timeout=None, clear_on_success=True, - expected_jwts=[expected_jwt] if expected_jwt else None, + input_jwts=[expected_jwt] if expected_jwt else None, ) ).outputs[0] return await _process_result(item.result, item.data_format, self.stub, self.client)