Skip to content

Commit

Permalink
Remove usage of method placeholder functions (#2364)
Browse files Browse the repository at this point in the history
  • Loading branch information
devennavani authored Nov 27, 2024
1 parent 04ce2fa commit 5d793ef
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 210 deletions.
2 changes: 1 addition & 1 deletion modal/_runtime/container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ async def _dynamic_concurrency_loop(self):
await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)

async def get_app_objects(self) -> RunningApp:
req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True)
req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True, only_class_function=True)
resp = await retry_transient_errors(self._client.stub.AppGetObjects, req)
logger.debug(f"AppGetObjects received {len(resp.items)} objects for app {self.app_id}")

Expand Down
2 changes: 1 addition & 1 deletion modal/cli/import_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_by_object_path(obj: Any, obj_path: str) -> Optional[Any]:
def _infer_function_or_help(
app: App, module, accept_local_entrypoint: bool, accept_webhook: bool
) -> Union[Function, LocalEntrypoint]:
function_choices = set(tag for tag, func in app.registered_functions.items() if not func.info.is_service_class())
function_choices = set(app.registered_functions)
if not accept_webhook:
function_choices -= set(app.registered_web_endpoints)
if accept_local_entrypoint:
Expand Down
21 changes: 17 additions & 4 deletions modal/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,33 @@ def _get_clean_app_description(func_ref: str) -> str:


def _get_click_command_for_function(app: App, function_tag):
function = app.registered_functions[function_tag]
function = app.registered_functions.get(function_tag)
if not function or (isinstance(function, Function) and function.info.user_cls is not None):
# This is either a function_tag for a class method function (e.g MyClass.foo) or a function tag for a
# class service function (MyClass.*)
class_name, method_name = function_tag.rsplit(".", 1)
if not function:
function = app.registered_functions.get(f"{class_name}.*")
assert isinstance(function, Function)
function = typing.cast(Function, function)
if function.is_generator:
raise InvalidError("`modal run` is not supported for generator functions")

signature: Dict[str, ParameterMetadata]
cls: Optional[Cls] = None
method_name: Optional[str] = None
if function.info.user_cls is not None:
class_name, method_name = function_tag.rsplit(".", 1)
cls = typing.cast(Cls, app.registered_classes[class_name])
cls_signature = _get_signature(function.info.user_cls)
fun_signature = _get_signature(function.info.raw_f, is_method=True)
if method_name == "*":
method_names = list(cls._get_partial_functions().keys())
if len(method_names) == 1:
method_name = method_names[0]
else:
class_name = function.info.user_cls.__name__
raise click.UsageError(
f"Please specify a specific method of {class_name} to run, e.g. `modal run foo.py::MyClass.bar`" # noqa: E501
)
fun_signature = _get_signature(getattr(cls, method_name).info.raw_f, is_method=True)
signature = dict(**cls_signature, **fun_signature) # Pool all arguments
# TODO(erikbern): assert there's no overlap?
else:
Expand Down
93 changes: 49 additions & 44 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class _Cls(_Object, type_prefix="cs"):
_class_service_function: Optional[
_Function
] # The _Function serving *all* methods of the class, used for version >=v0.63
_method_functions: Dict[str, _Function] # Placeholder _Functions for each method
_method_functions: Optional[Dict[str, _Function]] = None # Placeholder _Functions for each method
_options: Optional[api_pb2.FunctionOptions]
_callables: Dict[str, Callable[..., Any]]
_from_other_workspace: Optional[bool] # Functions require FunctionBindParams before invocation.
Expand All @@ -253,7 +253,6 @@ class _Cls(_Object, type_prefix="cs"):
def _initialize_from_empty(self):
self._user_cls = None
self._class_service_function = None
self._method_functions = {}
self._options = None
self._callables = {}
self._from_other_workspace = None
Expand All @@ -273,28 +272,46 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]:

def _hydrate_metadata(self, metadata: Message):
assert isinstance(metadata, api_pb2.ClassHandleMetadata)

for method in metadata.methods:
if method.function_name in self._method_functions:
# This happens when the class is loaded locally
# since each function will already be a loaded dependency _Function
self._method_functions[method.function_name]._hydrate(
method.function_id, self._client, method.function_handle_metadata
)
if (
self._class_service_function
and self._class_service_function._method_handle_metadata
and len(self._class_service_function._method_handle_metadata)
):
# The class only has a class service service function and no method placeholders (v0.67+)
if self._method_functions:
# We're here when the Cls is loaded locally (e.g. _Cls.from_local) so the _method_functions mapping is
# populated with (un-hydrated) _Function objects
for (
method_name,
method_handle_metadata,
) in self._class_service_function._method_handle_metadata.items():
self._method_functions[method_name]._hydrate(
self._class_service_function.object_id, self._client, method_handle_metadata
)
else:
# We're here when the function is loaded remotely (e.g. _Cls.from_name)
self._method_functions = {}
for (
method_name,
method_handle_metadata,
) in self._class_service_function._method_handle_metadata.items():
self._method_functions[method_name] = _Function._new_hydrated(
self._class_service_function.object_id, self._client, method_handle_metadata
)
elif self._class_service_function:
# A class with a class service function and method placeholder functions
self._method_functions = {}
for method in metadata.methods:
self._method_functions[method.function_name] = _Function._new_hydrated(
method.function_id, self._client, method.function_handle_metadata
self._class_service_function.object_id, self._client, method.function_handle_metadata
)

def _get_metadata(self) -> api_pb2.ClassHandleMetadata:
class_handle_metadata = api_pb2.ClassHandleMetadata()
for f_name, f in self._method_functions.items():
class_handle_metadata.methods.append(
api_pb2.ClassMethod(
function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata()
else:
# pre 0.63 class that does not have a class service function and only method functions
self._method_functions = {}
for method in metadata.methods:
self._method_functions[method.function_name] = _Function._new_hydrated(
method.function_id, self._client, method.function_handle_metadata
)
)
return class_handle_metadata

@staticmethod
def validate_construction_mechanism(user_cls):
Expand Down Expand Up @@ -327,56 +344,43 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio
# validate signature
_Cls.validate_construction_mechanism(user_cls)

functions: Dict[str, _Function] = {}
method_functions: Dict[str, _Function] = {}
partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls(
user_cls, _PartialFunctionFlags.FUNCTION
)

for method_name, partial_function in partial_functions.items():
method_function = class_service_function._bind_method_old(user_cls, method_name, partial_function)
app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None)
method_function = class_service_function._bind_method(user_cls, method_name, partial_function)
if partial_function.webhook_config is not None:
app._web_endpoints.append(method_function.tag)
partial_function.wrapped = True
functions[method_name] = method_function
method_functions[method_name] = method_function

# Disable the warning that these are not wrapped
for partial_function in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags.FUNCTION).values():
partial_function.wrapped = True

# Get all callables
callables: Dict[str, Callable] = {
k: pf.raw_f for k, pf in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags(0)).items()
k: pf.raw_f for k, pf in _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.all()).items()
}

def _deps() -> List[_Function]:
return [class_service_function] + list(functions.values())
return [class_service_function]

async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[str]):
req = api_pb2.ClassCreateRequest(app_id=resolver.app_id, existing_class_id=existing_object_id)
for f_name, f in self._method_functions.items():
req.methods.append(
api_pb2.ClassMethod(
function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata()
)
)
req = api_pb2.ClassCreateRequest(
app_id=resolver.app_id, existing_class_id=existing_object_id, only_class_function=True
)
resp = await resolver.client.stub.ClassCreate(req)
# Even though we already have the function_handle_metadata for this method locally,
# The RPC is going to replace it with function_handle_metadata derived from the server.
# We need to overwrite the definition_id sent back from the server here with the definition_id
# previously stored in function metadata, which may have been sent back from FunctionCreate.
# The problem is that this metadata propagates back and overwrites the metadata on the Function
# object itself. This is really messy. Maybe better to exclusively populate the method metadata
# from the function metadata we already have locally? Really a lot to clean up here...
for method in resp.handle_metadata.methods:
f_metadata = self._method_functions[method.function_name]._get_metadata()
method.function_handle_metadata.definition_id = f_metadata.definition_id
self._hydrate(resp.class_id, resolver.client, resp.handle_metadata)

rep = f"Cls({user_cls.__name__})"
cls: _Cls = _Cls._from_loader(_load, rep, deps=_deps)
cls._app = app
cls._user_cls = user_cls
cls._class_service_function = class_service_function
cls._method_functions = functions
cls._method_functions = method_functions
cls._callables = callables
cls._from_other_workspace = False
return cls
Expand Down Expand Up @@ -415,6 +419,7 @@ async def _load_remote(obj: _Object, resolver: Resolver, existing_object_id: Opt
environment_name=_environment_name,
lookup_published=workspace is not None,
workspace_name=workspace,
only_class_function=True,
)
try:
response = await retry_transient_errors(resolver.client.stub.ClassGet, request)
Expand Down
102 changes: 4 additions & 98 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def create(
function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
) -> "_Invocation":
assert client.stub
function_id = function._invocation_function_id()
function_id = function.object_id
item = await _create_input(args, kwargs, client, method_name=function._use_method_name)

request = api_pb2.FunctionMapRequest(
Expand Down Expand Up @@ -319,8 +319,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
_cluster_size: Optional[int] = None

# when this is the method of a class/object function, invocation of this function
# should be using another function id and supply the method name in the FunctionInput:
_use_function_id: str # The function to invoke
# should supply the method name in the FunctionInput:
_use_method_name: str = ""

_class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None
Expand Down Expand Up @@ -360,94 +359,6 @@ def _bind_method(
fun._is_method = True
return fun

def _bind_method_old(
self,
user_cls,
method_name: str,
partial_function: "modal.partial_function._PartialFunction",
):
"""mdmd:hidden
Creates a function placeholder function that binds a specific method name to
this function for use when invoking the function.
Should only be used on "class service functions". For "instance service functions",
we don't create an actual backend function, and instead do client-side "fake-hydration"
only, see _bind_instance_method.
"""
class_service_function = self
assert class_service_function._info # has to be a local function to be able to "bind" it
assert not class_service_function._is_method # should not be used on an already bound method placeholder
assert not class_service_function._obj # should only be used on base function / class service function
full_name = f"{user_cls.__name__}.{method_name}"
function_type = get_function_type(partial_function.is_generator)

async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]):
function_definition = api_pb2.Function(
function_name=full_name,
webhook_config=partial_function.webhook_config,
function_type=function_type,
is_method=True,
use_function_id=class_service_function.object_id,
use_method_name=method_name,
batch_max_size=partial_function.batch_max_size or 0,
batch_linger_ms=partial_function.batch_wait_ms or 0,
)
assert resolver.app_id
request = api_pb2.FunctionCreateRequest(
app_id=resolver.app_id,
function=function_definition,
# method_bound_function.object_id usually gets set by preload
existing_function_id=existing_object_id or method_bound_function.object_id or "",
defer_updates=True,
)
assert resolver.client.stub is not None # client should be connected when load is called
with FunctionCreationStatus(resolver, full_name) as function_creation_status:
response = await resolver.client.stub.FunctionCreate(request)
method_bound_function._hydrate(
response.function_id,
resolver.client,
response.handle_metadata,
)
function_creation_status.set_response(response)

async def _preload(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]):
if class_service_function._use_method_name:
raise ExecutionError(f"Can't bind method to already bound {class_service_function}")
assert resolver.app_id
req = api_pb2.FunctionPrecreateRequest(
app_id=resolver.app_id,
function_name=full_name,
function_type=function_type,
webhook_config=partial_function.webhook_config,
use_function_id=class_service_function.object_id,
use_method_name=method_name,
existing_function_id=existing_object_id or "",
)
assert resolver.client.stub # client should be connected at this point
response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req)
method_bound_function._hydrate(response.function_id, resolver.client, response.handle_metadata)

def _deps():
return [class_service_function]

rep = f"Method({full_name})"

fun = _Function._from_loader(_load, rep, preload=_preload, deps=_deps)
fun._tag = full_name
fun._raw_f = partial_function.raw_f
fun._info = FunctionInfo(
partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized()
) # needed for .local()
fun._use_method_name = method_name
fun._app = class_service_function._app
fun._is_generator = partial_function.is_generator
fun._cluster_size = partial_function.cluster_size
fun._spec = class_service_function._spec
fun._is_method = True
return fun

def _bind_instance_method(self, class_bound_method: "_Function"):
"""mdmd:hidden
Expand Down Expand Up @@ -475,7 +386,6 @@ def hydrate_from_instance_service_function(method_placeholder_fun):
method_placeholder_fun._is_generator = class_bound_method._is_generator
method_placeholder_fun._cluster_size = class_bound_method._cluster_size
method_placeholder_fun._use_method_name = method_name
method_placeholder_fun._use_function_id = instance_service_function.object_id
method_placeholder_fun._is_method = True

async def _load(fun: "_Function", resolver: Resolver, existing_object_id: Optional[str]):
Expand Down Expand Up @@ -848,6 +758,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
class_serialized=class_serialized or b"",
function_type=function_type,
webhook_config=webhook_config,
method_definitions=method_definitions,
method_definitions_set=True,
shared_volume_mounts=network_file_system_mount_protos(
validated_network_file_systems, allow_cross_region_volumes
),
Expand Down Expand Up @@ -1224,7 +1136,6 @@ def _initialize_from_empty(self):
self._web_url = None
self._function_name = None
self._info = None
self._use_function_id = ""
self._serve_mounts = frozenset()

def _hydrate_metadata(self, metadata: Optional[Message]):
Expand All @@ -1234,15 +1145,11 @@ def _hydrate_metadata(self, metadata: Optional[Message]):
self._web_url = metadata.web_url
self._function_name = metadata.function_name
self._is_method = metadata.is_method
self._use_function_id = metadata.use_function_id
self._use_method_name = metadata.use_method_name
self._class_parameter_info = metadata.class_parameter_info
self._method_handle_metadata = dict(metadata.method_handle_metadata)
self._definition_id = metadata.definition_id

def _invocation_function_id(self) -> str:
return self._use_function_id or self.object_id

def _get_metadata(self):
# Overridden concrete implementation of base class method
assert self._function_name, f"Function name must be set before metadata can be retrieved for {self}"
Expand All @@ -1251,7 +1158,6 @@ def _get_metadata(self):
function_type=get_function_type(self._is_generator),
web_url=self._web_url or "",
use_method_name=self._use_method_name,
use_function_id=self._use_function_id,
is_method=self._is_method,
class_parameter_info=self._class_parameter_info,
definition_id=self._definition_id,
Expand Down
Loading

0 comments on commit 5d793ef

Please sign in to comment.