Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental spawn now uses FunctionAsyncInvoke #2627

Open
wants to merge 5 commits into
base: dshaar/spawn-off-async-legacy
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,13 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,


async def _create_input(
args, kwargs, client, *, idx: Optional[int] = None, method_name: Optional[str] = None
args,
kwargs,
client,
*,
idx: Optional[int] = None,
method_name: Optional[str] = None,
force_blob_upload: bool = False,
) -> api_pb2.FunctionPutInputsItem:
"""Serialize function arguments and create a FunctionInput protobuf,
uploading to blob storage if needed.
Expand All @@ -507,7 +513,7 @@ async def _create_input(

args_serialized = serialize((args, kwargs))

if len(args_serialized) > MAX_OBJECT_SIZE_BYTES:
if len(args_serialized) > MAX_OBJECT_SIZE_BYTES or force_blob_upload:
args_blob_id = await blob_upload(args_serialized, client.stub)

return api_pb2.FunctionPutInputsItem(
Expand Down
31 changes: 25 additions & 6 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,14 +1392,33 @@ async def _experimental_spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_Func
self._check_no_web_url("_experimental_spawn")
if self._is_generator:
invocation = await self._call_generator_nowait(args, kwargs)
else:
invocation = await self._call_function_nowait(
args, kwargs, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
fc._is_generator = True
return fc

input = (await _create_input(args, kwargs, self._client, method_name=self._use_method_name)).input
request = api_pb2.FunctionAsyncInvokeRequest(
function_id=self.object_id,
parent_input_id=current_input_id() or "",
input=input,
)
response = await retry_transient_errors(self._client.stub.FunctionAsyncInvoke, request)

# If the server backpressures because the input size is too large, blob upload the input and retry.
if response.retry_with_blob_upload:
input = (
await _create_input(
args, kwargs, self._client, method_name=self._use_method_name, force_blob_upload=True
)
).input
request = api_pb2.FunctionAsyncInvokeRequest(
function_id=self.object_id,
parent_input_id=current_input_id() or "",
input=input,
)
response = await retry_transient_errors(self._client.stub.FunctionAsyncInvoke, request)

fc = _FunctionCall._new_hydrated(invocation.function_call_id, invocation.client, None)
fc._is_generator = self._is_generator if self._is_generator else False
return fc
return _FunctionCall._new_hydrated(response.function_call_id, self._client, None)

@synchronizer.no_input_translation
@live_method
Expand Down
7 changes: 7 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,13 @@ async def EnvironmentGetOrCreate(self, stream):

### Function

async def FunctionAsyncInvoke(self, stream):
self.fcidx += 1
request: api_pb2.FunctionAsyncInvokeRequest = await stream.recv_message()
function_call_id = f"fc-{self.fcidx}"
self.function_id_for_function_call[function_call_id] = request.function_id
await stream.send_message(api_pb2.FunctionAsyncInvokeResponse(function_call_id=function_call_id))

async def FunctionBindParams(self, stream):
from modal._serialization import deserialize

Expand Down
7 changes: 4 additions & 3 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import modal
from modal import App, Image, Mount, NetworkFileSystem, Proxy, asgi_app, batched, web_endpoint
from modal._serialization import deserialize
from modal._utils.async_utils import synchronize_api
from modal._vendor import cloudpickle
from modal.exception import DeprecationError, ExecutionError, InvalidError
Expand Down Expand Up @@ -1009,9 +1010,9 @@ def test_experimental_spawn(client, servicer):
with app.run(client=client):
dummy_modal._experimental_spawn(1, 2)

# Verify the correct invocation type is set
function_map = ctx.pop_request("FunctionMap")
assert function_map.function_call_invocation_type == api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
# Verify the correct input was passed to the function.
request = ctx.pop_request("FunctionAsyncInvoke")
assert deserialize(request.input.args, client) == ((1, 2), {})


def test_from_name_web_url(servicer, set_env_client):
Expand Down
Loading