Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MOD-4118: handle deserialization errors #2342

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ def __init__(
input_ids: List[str],
function_call_ids: List[str],
finalized_function: FinalizedFunction,
deserialized_args: List[Any],
function_inputs: List[api_pb2.FunctionInput],
is_batched: bool,
client: _Client,
):
self.input_ids = input_ids
self.function_call_ids = function_call_ids
self.finalized_function = finalized_function
self._deserialized_args = deserialized_args
self._function_inputs = function_inputs
self._is_batched = is_batched
self._client = client

@classmethod
async def create(
Expand Down Expand Up @@ -123,9 +125,7 @@ async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -
method_name = function_inputs[0].method_name
assert all(method_name == input.method_name for input in function_inputs)
finalized_function = finalized_functions[method_name]
# TODO(cathy) Performance decrease if we deserialize function_inputs later
deserialized_args = [deserialize(input.args, client) if input.args else ((), {}) for input in function_inputs]
return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched)
return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)

def set_cancel_callback(self, cb: Callable[[], None]):
self._cancel_callback = cb
Expand All @@ -141,8 +141,14 @@ def cancel(self):
logger.warning("Unexpected: Could not cancel input")

def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]:
# deserializing here instead of the constructor
# to make sure we handle user exceptions properly
# and don't retry
deserialized_args = [
deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
]
if not self._is_batched:
return self._deserialized_args[0]
return deserialized_args[0]

func_name = self.finalized_function.callable.__name__

Expand All @@ -152,7 +158,8 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]:

# aggregate args and kwargs of all inputs into a kwarg dict
kwargs_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(self.input_ids))]
for i, (args, kwargs) in enumerate(self._deserialized_args):

for i, (args, kwargs) in enumerate(deserialized_args):
# check that all batched inputs should have the same number of args and kwargs
if (num_params := len(args) + len(kwargs)) != len(param_names):
raise InvalidError(
Expand Down
32 changes: 32 additions & 0 deletions test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2256,3 +2256,35 @@ def test_set_local_input_concurrency(servicer):
def test_sandbox_infers_app(servicer, event_loop):
_run_container(servicer, "test.supports.sandbox", "spawn_sandbox")
assert servicer.sandbox_app_id == "ap-1"


@skip_github_non_linux
def test_deserialization_error_returns_exception(servicer, client):
inputs = [
api_pb2.FunctionGetInputsResponse(
inputs=[
api_pb2.FunctionGetInputsItem(
input_id="in-xyz0",
function_call_id="fc-123",
input=api_pb2.FunctionInput(
args=b"\x80\x04\x95(\x00\x00\x00\x00\x00\x00\x00\x8c\x17",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of a byte string here, can you feed it a pickled object? or is that impossible because it needs to be a different package?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah would have to somehow pickle the object in another venv than it's being pickled in which sounds annoying but more realistic, but both of these causes serialization errors

data_format=api_pb2.DATA_FORMAT_PICKLE,
method_name="",
),
),
]
),
*_get_inputs(((2,), {})),
]
ret = _run_container(
servicer,
"test.supports.functions",
"square",
inputs=inputs,
)
assert len(ret.items) == 2
assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
assert "DeserializationError" in ret.items[0].result.exception

assert ret.items[1].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
assert int(deserialize(ret.items[1].result.data, ret.client)) == 4
Loading