From f77a03545e465559a215079daf9cab594d4ff548 Mon Sep 17 00:00:00 2001 From: kramstrom Date: Wed, 16 Oct 2024 12:49:55 +0200 Subject: [PATCH 1/3] deserialize --- modal/_container_io_manager.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 7140e389c..f4daa3d84 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -88,14 +88,16 @@ def __init__( input_ids: List[str], function_call_ids: List[str], finalized_function: FinalizedFunction, - deserialized_args: List[Any], + function_inputs: List[api_pb2.FunctionInput], is_batched: bool, + client: _Client, ): self.input_ids = input_ids self.function_call_ids = function_call_ids self.finalized_function = finalized_function - self._deserialized_args = deserialized_args + self._function_inputs = function_inputs self._is_batched = is_batched + self._client = client @classmethod async def create( @@ -123,9 +125,7 @@ async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) - method_name = function_inputs[0].method_name assert all(method_name == input.method_name for input in function_inputs) finalized_function = finalized_functions[method_name] - # TODO(cathy) Performance decrease if we deserialize function_inputs later - deserialized_args = [deserialize(input.args, client) if input.args else ((), {}) for input in function_inputs] - return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched) + return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client) def set_cancel_callback(self, cb: Callable[[], None]): self._cancel_callback = cb @@ -141,8 +141,14 @@ def cancel(self): logger.warning("Unexpected: Could not cancel input") def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]: + # deserializing here instead of the constructor + # to make sure we handle user exceptions properly + # and don't retry + deserialized_args = [ + deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs + ] if not self._is_batched: - return self._deserialized_args[0] + return deserialized_args[0] func_name = self.finalized_function.callable.__name__ @@ -152,7 +158,8 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]: # aggregate args and kwargs of all inputs into a kwarg dict kwargs_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(self.input_ids))] - for i, (args, kwargs) in enumerate(self._deserialized_args): + + for i, (args, kwargs) in enumerate(deserialized_args): # check that all batched inputs should have the same number of args and kwargs if (num_params := len(args) + len(kwargs)) != len(param_names): raise InvalidError( From 57b19f9d06bdde646756f5d280d55392a7755110 Mon Sep 17 00:00:00 2001 From: kramstrom Date: Wed, 16 Oct 2024 13:32:31 +0200 Subject: [PATCH 2/3] test --- test/container_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/container_test.py b/test/container_test.py index f01c5f1e4..c4ce220cc 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -2256,3 +2256,35 @@ def test_set_local_input_concurrency(servicer): def test_sandbox_infers_app(servicer, event_loop): _run_container(servicer, "test.supports.sandbox", "spawn_sandbox") assert servicer.sandbox_app_id == "ap-1" + + +@skip_github_non_linux +def test_deserialization_error_returns_exception(servicer, client): + inputs = [ + api_pb2.FunctionGetInputsResponse( + inputs=[ + api_pb2.FunctionGetInputsItem( + input_id="in-xyz0", + function_call_id="fc-123", + input=api_pb2.FunctionInput( + args=b"\x80\x04\x95(\x00\x00\x00\x00\x00\x00\x00\x8c\x17", + data_format=api_pb2.DATA_FORMAT_PICKLE, + method_name="", + ), + ), + ] + ), + *_get_inputs(((2,), {})), + ] + ret = _run_container( + servicer, + "test.supports.functions", + "square", + inputs=inputs, + ) + assert len(ret.items) == 2 + assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE + assert "DeserializationError" in ret.items[0].result.exception + + assert ret.items[1].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + assert int(deserialize(ret.items[1].result.data, ret.client)) == 4 From 226184be7f260a182a79c5694e1ee52d0352caca Mon Sep 17 00:00:00 2001 From: kramstrom Date: Thu, 17 Oct 2024 10:18:16 +0200 Subject: [PATCH 3/3] fix logs --- modal/_container_io_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index f4daa3d84..92f099163 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -185,10 +185,10 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]: return (), formatted_kwargs def call_finalized_function(self) -> Any: + logger.debug(f"Starting input {self.input_ids}") args, kwargs = self._args_and_kwargs() - logger.debug(f"Starting input {self.input_ids} (async)") res = self.finalized_function.callable(*args, **kwargs) - logger.debug(f"Finished input {self.input_ids} (async)") + logger.debug(f"Finished input {self.input_ids}") return res def validate_output_data(self, data: Any) -> List[Any]: