Skip to content

Commit

Permalink
Merge branch 'main' of github.com:modal-labs/modal-client into kramst…
Browse files Browse the repository at this point in the history
…rom/mod-2454-remove-aiostream-dependency-async_map
  • Loading branch information
kramstrom committed Oct 17, 2024
2 parents b644bbe + 5610347 commit c6a41f0
Showing 3 changed files with 49 additions and 10 deletions.
25 changes: 16 additions & 9 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
@@ -88,14 +88,16 @@ def __init__(
input_ids: List[str],
function_call_ids: List[str],
finalized_function: FinalizedFunction,
deserialized_args: List[Any],
function_inputs: List[api_pb2.FunctionInput],
is_batched: bool,
client: _Client,
):
self.input_ids = input_ids
self.function_call_ids = function_call_ids
self.finalized_function = finalized_function
self._deserialized_args = deserialized_args
self._function_inputs = function_inputs
self._is_batched = is_batched
self._client = client

@classmethod
async def create(
@@ -123,9 +125,7 @@ async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -
method_name = function_inputs[0].method_name
assert all(method_name == input.method_name for input in function_inputs)
finalized_function = finalized_functions[method_name]
# TODO(cathy) Performance decrease if we deserialize function_inputs later
deserialized_args = [deserialize(input.args, client) if input.args else ((), {}) for input in function_inputs]
return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched)
return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)

def set_cancel_callback(self, cb: Callable[[], None]):
self._cancel_callback = cb
@@ -141,8 +141,14 @@ def cancel(self):
logger.warning("Unexpected: Could not cancel input")

def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]:
# deserializing here instead of the constructor
# to make sure we handle user exceptions properly
# and don't retry
deserialized_args = [
deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
]
if not self._is_batched:
return self._deserialized_args[0]
return deserialized_args[0]

func_name = self.finalized_function.callable.__name__

@@ -152,7 +158,8 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]:

# aggregate args and kwargs of all inputs into a kwarg dict
kwargs_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(self.input_ids))]
for i, (args, kwargs) in enumerate(self._deserialized_args):

for i, (args, kwargs) in enumerate(deserialized_args):
# check that all batched inputs should have the same number of args and kwargs
if (num_params := len(args) + len(kwargs)) != len(param_names):
raise InvalidError(
@@ -178,10 +185,10 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List[Any]]]:
return (), formatted_kwargs

def call_finalized_function(self) -> Any:
logger.debug(f"Starting input {self.input_ids}")
args, kwargs = self._args_and_kwargs()
logger.debug(f"Starting input {self.input_ids} (async)")
res = self.finalized_function.callable(*args, **kwargs)
logger.debug(f"Finished input {self.input_ids} (async)")
logger.debug(f"Finished input {self.input_ids}")
return res

def validate_output_data(self, data: Any) -> List[Any]:
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Modal Labs 2024

# Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
build_number = 189 # git: 8952698
build_number = 190 # git: adbd096
32 changes: 32 additions & 0 deletions test/container_test.py
Original file line number Diff line number Diff line change
@@ -2256,3 +2256,35 @@ def test_set_local_input_concurrency(servicer):
def test_sandbox_infers_app(servicer, event_loop):
_run_container(servicer, "test.supports.sandbox", "spawn_sandbox")
assert servicer.sandbox_app_id == "ap-1"


@skip_github_non_linux
def test_deserialization_error_returns_exception(servicer, client):
inputs = [
api_pb2.FunctionGetInputsResponse(
inputs=[
api_pb2.FunctionGetInputsItem(
input_id="in-xyz0",
function_call_id="fc-123",
input=api_pb2.FunctionInput(
args=b"\x80\x04\x95(\x00\x00\x00\x00\x00\x00\x00\x8c\x17",
data_format=api_pb2.DATA_FORMAT_PICKLE,
method_name="",
),
),
]
),
*_get_inputs(((2,), {})),
]
ret = _run_container(
servicer,
"test.supports.functions",
"square",
inputs=inputs,
)
assert len(ret.items) == 2
assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE
assert "DeserializationError" in ret.items[0].result.exception

assert ret.items[1].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
assert int(deserialize(ret.items[1].result.data, ret.client)) == 4

0 comments on commit c6a41f0

Please sign in to comment.