From d4c278245820f8ecc4cd27d255a589404beca88c Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 20 Oct 2024 18:59:26 +0000 Subject: [PATCH 01/96] Miscellaneous cleanup of @app.cls code --- modal/app.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/modal/app.py b/modal/app.py index 75b60c3e4..f438cb128 100644 --- a/modal/app.py +++ b/modal/app.py @@ -884,11 +884,15 @@ def cls( (2024, 5, 1), "interactive=True has been deprecated. Set MODAL_INTERACTIVE_FUNCTIONS=1 instead." ) - if image is None: - image = self._get_default_image() - + image = image or self._get_default_image() secrets = [*self._secrets, *secrets] + scheduler_placement = _experimental_scheduler_placement + if region: + if scheduler_placement: + raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together") + scheduler_placement = SchedulerPlacement(region=region) + def wrapper(user_cls: CLS_T) -> CLS_T: nonlocal keep_warm @@ -896,14 +900,6 @@ def wrapper(user_cls: CLS_T) -> CLS_T: if not inspect.isclass(user_cls): raise TypeError("The @app.cls decorator must be used on a class.") - info = FunctionInfo(None, serialized=serialized, user_cls=user_cls) - - scheduler_placement: Optional[SchedulerPlacement] = _experimental_scheduler_placement - if region: - if scheduler_placement: - raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together") - scheduler_placement = SchedulerPlacement(region=region) - batch_functions = _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.BATCHED) if batch_functions: if len(batch_functions) > 1: @@ -919,6 +915,14 @@ def wrapper(user_cls: CLS_T) -> CLS_T: batch_max_size = None batch_wait_ms = None + if ( + _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.ENTER_PRE_SNAPSHOT) + and not enable_memory_snapshot + ): + raise InvalidError("A class must have `enable_memory_snapshot=True` to use `snap=True` on its methods.") + + info = FunctionInfo(None, serialized=serialized, user_cls=user_cls) + cls_func = _Function.from_args( info, app=self, @@ -947,24 +951,18 @@ def wrapper(user_cls: CLS_T) -> CLS_T: block_network=block_network, max_inputs=max_inputs, scheduler_placement=scheduler_placement, + _experimental_buffer_containers=_experimental_buffer_containers, + _experimental_proxy_ip=_experimental_proxy_ip, # class service function, so the following attributes which relate to # the callable itself are invalid and set to defaults: webhook_config=None, is_generator=False, - _experimental_buffer_containers=_experimental_buffer_containers, - _experimental_proxy_ip=_experimental_proxy_ip, ) self._add_function(cls_func, is_web_endpoint=False) cls: _Cls = _Cls.from_local(user_cls, self, cls_func) - if ( - _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.ENTER_PRE_SNAPSHOT) - and not enable_memory_snapshot - ): - raise InvalidError("A class must have `enable_memory_snapshot=True` to use `snap=True` on its methods.") - tag: str = user_cls.__name__ self._add_object(tag, cls) return cls # type: ignore # a _Cls instance "simulates" being the user provided class From bc850a122ea6520d969b87282c069d24e8542ca5 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 20 Oct 2024 19:23:30 +0000 Subject: [PATCH 02/96] More cleanup --- modal/app.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/modal/app.py b/modal/app.py index f438cb128..3ad89d713 100644 --- a/modal/app.py +++ b/modal/app.py @@ -884,9 +884,6 @@ def cls( (2024, 5, 1), "interactive=True has been deprecated. Set MODAL_INTERACTIVE_FUNCTIONS=1 instead." ) - image = image or self._get_default_image() - secrets = [*self._secrets, *secrets] - scheduler_placement = _experimental_scheduler_placement if region: if scheduler_placement: @@ -926,8 +923,8 @@ def wrapper(user_cls: CLS_T) -> CLS_T: cls_func = _Function.from_args( info, app=self, - image=image, - secrets=secrets, + image=image or self._get_default_image(), + secrets=[*self._secrets, *secrets], gpu=gpu, mounts=[*self._mounts, *mounts], network_file_systems=network_file_systems, From 1511cc96f495f932d108c09e2edecbcae38fe8a5 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 20 Oct 2024 19:30:32 +0000 Subject: [PATCH 03/96] Add comment --- modal/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modal/app.py b/modal/app.py index 3ad89d713..988310798 100644 --- a/modal/app.py +++ b/modal/app.py @@ -879,6 +879,7 @@ def cls( if _warn_parentheses_missing: raise InvalidError("Did you forget parentheses? Suggestion: `@app.cls()`.") + # Argument validation if interactive: deprecation_error( (2024, 5, 1), "interactive=True has been deprecated. Set MODAL_INTERACTIVE_FUNCTIONS=1 instead." From bc52e3806c86eab4259f7d71fcd9a37aa2372ab7 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 20 Oct 2024 20:05:55 +0000 Subject: [PATCH 04/96] More cleanup --- modal/app.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modal/app.py b/modal/app.py index 988310798..a4b6846ab 100644 --- a/modal/app.py +++ b/modal/app.py @@ -951,10 +951,6 @@ def wrapper(user_cls: CLS_T) -> CLS_T: scheduler_placement=scheduler_placement, _experimental_buffer_containers=_experimental_buffer_containers, _experimental_proxy_ip=_experimental_proxy_ip, - # class service function, so the following attributes which relate to - # the callable itself are invalid and set to defaults: - webhook_config=None, - is_generator=False, ) self._add_function(cls_func, is_web_endpoint=False) From a0e5bc0677232699c742ff7e6bb79c17408ae007 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 20 Oct 2024 21:12:45 +0000 Subject: [PATCH 05/96] Add class method webhook configs to definition of class service function --- modal/functions.py | 3 +++ modal_proto/api.proto | 3 +++ 2 files changed, 6 insertions(+) diff --git a/modal/functions.py b/modal/functions.py index 40d32c3e5..9abc41773 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -797,6 +797,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona raise Exception(f"Dependency {dep} isn't hydrated") object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id)) + method_webhook_configs = {} + function_data: Optional[api_pb2.FunctionData] = None function_definition: Optional[api_pb2.Function] = None @@ -812,6 +814,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona class_serialized=class_serialized or b"", function_type=function_type, webhook_config=webhook_config, + method_webhook_configs=method_webhook_configs if info.is_service_class() else None, shared_volume_mounts=network_file_system_mount_protos( validated_network_file_systems, allow_cross_region_volumes ), diff --git a/modal_proto/api.proto b/modal_proto/api.proto index a6960ffc9..d4339ff9c 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -1081,6 +1081,9 @@ message Function { Schedule schedule = 72; + // Need a mapping of method names to webhook configs, e.g. {"method_name": WebhookConfig} + map method_webhook_configs = 73; + } message FunctionBindParamsRequest { From c0ff2177bf81bb601a6447d051b29d094fd2a608 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 25 Oct 2024 18:28:29 +0000 Subject: [PATCH 06/96] wip --- modal/app.py | 9 +++++++++ modal/functions.py | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/modal/app.py b/modal/app.py index d4af4de32..75bec2636 100644 --- a/modal/app.py +++ b/modal/app.py @@ -919,6 +919,14 @@ def wrapper(user_cls: CLS_T) -> CLS_T: ): raise InvalidError("A class must have `enable_memory_snapshot=True` to use `snap=True` on its methods.") + method_web_endpoint_info = {} + partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( + user_cls, _PartialFunctionFlags.FUNCTION + ) + for method_name, partial_function in partial_functions.items(): + web_endpoint_info = api_pb2.WebEndpointInfo(webhook_config=partial_function.webhook_config) + method_web_endpoint_info[method_name] = web_endpoint_info + info = FunctionInfo(None, serialized=serialized, user_cls=user_cls) cls_func = _Function.from_args( @@ -944,6 +952,7 @@ def wrapper(user_cls: CLS_T) -> CLS_T: cpu=cpu, keep_warm=keep_warm, cloud=cloud, + method_web_endpoint_info=method_web_endpoint_info, enable_memory_snapshot=enable_memory_snapshot, checkpointing_enabled=checkpointing_enabled, block_network=block_network, diff --git a/modal/functions.py b/modal/functions.py index 00a33a2ee..170aa6215 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -499,6 +499,7 @@ def from_args( allow_cross_region_volumes: bool = False, volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {}, webhook_config: Optional[api_pb2.WebhookConfig] = None, + method_web_endpoint_info: Optional[Dict[str, api_pb2.WebEndpointInfo]] = None, memory: Optional[Union[int, Tuple[int, int]]] = None, proxy: Optional[_Proxy] = None, retries: Optional[Union[int, Retries]] = None, @@ -537,6 +538,7 @@ def from_args( else: # must be a "class service function" assert info.user_cls + assert method_web_endpoint_info assert not webhook_config assert not schedule @@ -796,8 +798,6 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona raise Exception(f"Dependency {dep} isn't hydrated") object_dependencies.append(api_pb2.ObjectDependency(object_id=dep.object_id)) - method_webhook_configs = {} - function_data: Optional[api_pb2.FunctionData] = None function_definition: Optional[api_pb2.Function] = None @@ -813,7 +813,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona class_serialized=class_serialized or b"", function_type=function_type, webhook_config=webhook_config, - method_webhook_configs=method_webhook_configs if info.is_service_class() else None, + method_web_endpoint_info=method_web_endpoint_info, shared_volume_mounts=network_file_system_mount_protos( validated_network_file_systems, allow_cross_region_volumes ), From 5414a367d1bc5399218aa5dfc5afb986ba998b5b Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 25 Oct 2024 20:44:16 +0000 Subject: [PATCH 07/96] stage --- modal/cls.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index 4610706d8..764a1dc3b 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -318,7 +318,7 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio for method_name, partial_function in partial_functions.items(): method_function = class_service_function._bind_method(user_cls, method_name, partial_function) - app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) + # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) partial_function.wrapped = True functions[method_name] = method_function @@ -332,7 +332,8 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio } def _deps() -> List[_Function]: - return [class_service_function] + list(functions.values()) + # return [class_service_function] + list(functions.values()) + return [class_service_function] async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.ClassCreateRequest(app_id=resolver.app_id, existing_class_id=existing_object_id) From bc6bd79e80c879c78d9a2d182f1b2e7cbc1f0cfa Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 19:05:18 +0000 Subject: [PATCH 08/96] wip --- modal/app.py | 19 +++++-- modal/cls.py | 76 ++++++++++++------------- modal/functions.py | 129 ++++++++++++++++++++++-------------------- modal_proto/api.proto | 32 ++++++----- 4 files changed, 140 insertions(+), 116 deletions(-) diff --git a/modal/app.py b/modal/app.py index 75bec2636..a5c7bf438 100644 --- a/modal/app.py +++ b/modal/app.py @@ -919,13 +919,19 @@ def wrapper(user_cls: CLS_T) -> CLS_T: ): raise InvalidError("A class must have `enable_memory_snapshot=True` to use `snap=True` on its methods.") - method_web_endpoint_info = {} + method_definitions = {} partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( user_cls, _PartialFunctionFlags.FUNCTION ) for method_name, partial_function in partial_functions.items(): - web_endpoint_info = api_pb2.WebEndpointInfo(webhook_config=partial_function.webhook_config) - method_web_endpoint_info[method_name] = web_endpoint_info + if partial_function.is_generator: + function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR + else: + function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + method_definition = api_pb2.MethodDefinition( + webhook_config=partial_function.webhook_config, function_type=function_type + ) + method_definitions[method_name] = method_definition info = FunctionInfo(None, serialized=serialized, user_cls=user_cls) @@ -952,7 +958,7 @@ def wrapper(user_cls: CLS_T) -> CLS_T: cpu=cpu, keep_warm=keep_warm, cloud=cloud, - method_web_endpoint_info=method_web_endpoint_info, + method_definitions=method_definitions, enable_memory_snapshot=enable_memory_snapshot, checkpointing_enabled=checkpointing_enabled, block_network=block_network, @@ -962,6 +968,11 @@ def wrapper(user_cls: CLS_T) -> CLS_T: _experimental_proxy_ip=_experimental_proxy_ip, ) + cls_func._method_functions = {} + for method_name, partial_function in partial_functions.items(): + method_function = cls_func._bind_method(user_cls, method_name, partial_function) + cls_func._method_functions[method_name] = method_function + self._add_function(cls_func, is_web_endpoint=False) cls: _Cls = _Cls.from_local(user_cls, self, cls_func) diff --git a/modal/cls.py b/modal/cls.py index 764a1dc3b..1d4354952 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -259,26 +259,26 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) - for method in metadata.methods: - if method.function_name in self._method_functions: - # This happens when the class is loaded locally - # since each function will already be a loaded dependency _Function - self._method_functions[method.function_name]._hydrate( - method.function_id, self._client, method.function_handle_metadata - ) - else: - self._method_functions[method.function_name] = _Function._new_hydrated( - method.function_id, self._client, method.function_handle_metadata - ) + # for method in metadata.methods: + # if method.function_name in self._method_functions: + # # This happens when the class is loaded locally + # # since each function will already be a loaded dependency _Function + # self._method_functions[method.function_name]._hydrate( + # method.function_id, self._client, method.function_handle_metadata + # ) + # else: + # self._method_functions[method.function_name] = _Function._new_hydrated( + # method.function_id, self._client, method.function_handle_metadata + # ) def _get_metadata(self) -> api_pb2.ClassHandleMetadata: class_handle_metadata = api_pb2.ClassHandleMetadata() - for f_name, f in self._method_functions.items(): - class_handle_metadata.methods.append( - api_pb2.ClassMethod( - function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata() - ) - ) + # for f_name, f in self._method_functions.items(): + # class_handle_metadata.methods.append( + # api_pb2.ClassMethod( + # function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata() + # ) + # ) return class_handle_metadata @staticmethod @@ -311,16 +311,16 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio # validate signature _Cls.validate_construction_mechanism(user_cls) - functions: Dict[str, _Function] = {} - partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( - user_cls, _PartialFunctionFlags.FUNCTION - ) + # functions: Dict[str, _Function] = {} + # partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( + # user_cls, _PartialFunctionFlags.FUNCTION + # ) - for method_name, partial_function in partial_functions.items(): - method_function = class_service_function._bind_method(user_cls, method_name, partial_function) - # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) - partial_function.wrapped = True - functions[method_name] = method_function + # for method_name, partial_function in partial_functions.items(): + # method_function = class_service_function._bind_method(user_cls, method_name, partial_function) + # # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) + # partial_function.wrapped = True + # functions[method_name] = method_function # Disable the warning that these are not wrapped for partial_function in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags.FUNCTION).values(): @@ -337,12 +337,12 @@ def _deps() -> List[_Function]: async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.ClassCreateRequest(app_id=resolver.app_id, existing_class_id=existing_object_id) - for f_name, f in self._method_functions.items(): - req.methods.append( - api_pb2.ClassMethod( - function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata() - ) - ) + # for f_name, f in self._method_functions.items(): + # req.methods.append( + # api_pb2.ClassMethod( + # function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata() + # ) + # ) resp = await resolver.client.stub.ClassCreate(req) # Even though we already have the function_handle_metadata for this method locally, # The RPC is going to replace it with function_handle_metadata derived from the server. @@ -351,9 +351,9 @@ async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[s # The problem is that this metadata propagates back and overwrites the metadata on the Function # object itself. This is really messy. Maybe better to exclusively populate the method metadata # from the function metadata we already have locally? Really a lot to clean up here... - for method in resp.handle_metadata.methods: - f_metadata = self._method_functions[method.function_name]._get_metadata() - method.function_handle_metadata.definition_id = f_metadata.definition_id + # for method in resp.handle_metadata.methods: + # f_metadata = self._method_functions[method.function_name]._get_metadata() + # method.function_handle_metadata.definition_id = f_metadata.definition_id self._hydrate(resp.class_id, resolver.client, resp.handle_metadata) rep = f"Cls({user_cls.__name__})" @@ -361,7 +361,7 @@ async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[s cls._app = app cls._user_cls = user_cls cls._class_service_function = class_service_function - cls._method_functions = functions + # cls._method_functions = functions cls._callables = callables cls._from_other_workspace = False return cls @@ -523,8 +523,8 @@ def __call__(self, *args, **kwargs) -> _Obj: def __getattr__(self, k): # Used by CLI and container entrypoint - if k in self._method_functions: - return self._method_functions[k] + if k in self._class_service_function._method_functions: + return self._class_service_function._method_functions[k] return getattr(self._user_cls, k) diff --git a/modal/functions.py b/modal/functions.py index 170aa6215..9ac3d2f6f 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -325,6 +325,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _parent: Optional["_Function"] = None _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None + _method_functions: Dict[str, "_Function"] # Placeholder _Functions for each method def _bind_method( self, @@ -348,65 +349,61 @@ def _bind_method( assert not class_service_function._obj # should only be used on base function / class service function full_name = f"{user_cls.__name__}.{method_name}" - if partial_function.is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION - - async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): - from ._output import FunctionCreationStatus # Deferred import to avoid Rich dependency in container - - function_definition = api_pb2.Function( - function_name=full_name, - webhook_config=partial_function.webhook_config, - function_type=function_type, - is_method=True, - use_function_id=class_service_function.object_id, - use_method_name=method_name, - batch_max_size=partial_function.batch_max_size or 0, - batch_linger_ms=partial_function.batch_wait_ms or 0, - ) - assert resolver.app_id - request = api_pb2.FunctionCreateRequest( - app_id=resolver.app_id, - function=function_definition, - # method_bound_function.object_id usually gets set by preload - existing_function_id=existing_object_id or method_bound_function.object_id or "", - defer_updates=True, - ) - assert resolver.client.stub is not None # client should be connected when load is called - with FunctionCreationStatus(resolver, full_name) as function_creation_status: - response = await resolver.client.stub.FunctionCreate(request) - method_bound_function._hydrate( - response.function_id, - resolver.client, - response.handle_metadata, - ) - function_creation_status.set_response(response) - - async def _preload(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): - if class_service_function._use_method_name: - raise ExecutionError(f"Can't bind method to already bound {class_service_function}") - assert resolver.app_id - req = api_pb2.FunctionPrecreateRequest( - app_id=resolver.app_id, - function_name=full_name, - function_type=function_type, - webhook_config=partial_function.webhook_config, - use_function_id=class_service_function.object_id, - use_method_name=method_name, - existing_function_id=existing_object_id or "", - ) - assert resolver.client.stub # client should be connected at this point - response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) - method_bound_function._hydrate(response.function_id, resolver.client, response.handle_metadata) - - def _deps(): - return [class_service_function] + # async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): + # from ._output import FunctionCreationStatus # Deferred import to avoid Rich dependency in container + + # function_definition = api_pb2.Function( + # function_name=full_name, + # webhook_config=partial_function.webhook_config, + # function_type=function_type, + # is_method=True, + # use_function_id=class_service_function.object_id, + # use_method_name=method_name, + # batch_max_size=partial_function.batch_max_size or 0, + # batch_linger_ms=partial_function.batch_wait_ms or 0, + # ) + # assert resolver.app_id + # request = api_pb2.FunctionCreateRequest( + # app_id=resolver.app_id, + # function=function_definition, + # # method_bound_function.object_id usually gets set by preload + # existing_function_id=existing_object_id or method_bound_function.object_id or "", + # defer_updates=True, + # ) + # assert resolver.client.stub is not None # client should be connected when load is called + # with FunctionCreationStatus(resolver, full_name) as function_creation_status: + # response = await resolver.client.stub.FunctionCreate(request) + # method_bound_function._hydrate( + # response.function_id, + # resolver.client, + # response.handle_metadata, + # ) + # function_creation_status.set_response(response) + + # async def _preload(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): + # if class_service_function._use_method_name: + # raise ExecutionError(f"Can't bind method to already bound {class_service_function}") + # assert resolver.app_id + # req = api_pb2.FunctionPrecreateRequest( + # app_id=resolver.app_id, + # function_name=full_name, + # function_type=function_type, + # webhook_config=partial_function.webhook_config, + # use_function_id=class_service_function.object_id, + # use_method_name=method_name, + # existing_function_id=existing_object_id or "", + # ) + # assert resolver.client.stub # client should be connected at this point + # response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) + # method_bound_function._hydrate(response.function_id, resolver.client, response.handle_metadata) + + # def _deps(): + # return [class_service_function] rep = f"Method({full_name})" - fun = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) + fun = _Function(rep) + # fun = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) fun._tag = full_name fun._raw_f = partial_function.raw_f fun._info = FunctionInfo( @@ -499,7 +496,7 @@ def from_args( allow_cross_region_volumes: bool = False, volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {}, webhook_config: Optional[api_pb2.WebhookConfig] = None, - method_web_endpoint_info: Optional[Dict[str, api_pb2.WebEndpointInfo]] = None, + method_definitions: Optional[Dict[str, api_pb2.MethodDefinition]] = None, memory: Optional[Union[int, Tuple[int, int]]] = None, proxy: Optional[_Proxy] = None, retries: Optional[Union[int, Retries]] = None, @@ -538,7 +535,7 @@ def from_args( else: # must be a "class service function" assert info.user_cls - assert method_web_endpoint_info + assert method_definitions assert not webhook_config assert not schedule @@ -727,11 +724,20 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti app_id=resolver.app_id, function_name=info.function_name, function_type=function_type, - webhook_config=webhook_config, existing_function_id=existing_object_id or "", ) + if method_definitions: + method_webhook_configs = { + method_name: method_def.webhook_config for method_name, method_def in method_definitions.items() + } + req.method_webhook_configs.update(method_webhook_configs) + else: + req.webhook_config = webhook_config response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) + for method_name, method_function in self._method_functions: + method_handle_metadata = response.handle_metadata.method_handle_metadata[method_name] + method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): from ._output import FunctionCreationStatus # Deferred import to avoid Rich dependency in container @@ -813,7 +819,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona class_serialized=class_serialized or b"", function_type=function_type, webhook_config=webhook_config, - method_web_endpoint_info=method_web_endpoint_info, + method_definitions=method_definitions, shared_volume_mounts=network_file_system_mount_protos( validated_network_file_systems, allow_cross_region_volumes ), @@ -929,6 +935,9 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona function_creation_status.set_response(response) self._hydrate(response.function_id, resolver.client, response.handle_metadata) + for method_name, method_function in self._method_functions: + method_handle_metadata = response.handle_metadata.method_handle_metadata[method_name] + method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 370f5fd76..d1bdc3d87 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -953,6 +953,12 @@ message FileEntry { uint64 size = 4; } +enum FunctionType { + FUNCTION_TYPE_UNSPECIFIED = 0; + FUNCTION_TYPE_GENERATOR = 1; + FUNCTION_TYPE_FUNCTION = 2; +} + message Function { string module_name = 1; string function_name = 2; @@ -967,11 +973,6 @@ message Function { } DefinitionType definition_type = 7; - enum FunctionType { - FUNCTION_TYPE_UNSPECIFIED = 0; - FUNCTION_TYPE_GENERATOR = 1; - FUNCTION_TYPE_FUNCTION = 2; - } FunctionType function_type = 8; Resources resources = 9; @@ -1084,8 +1085,8 @@ message Function { bool snapshot_debug = 73; // For internal debugging use only. - // Need a mapping of method names to web endpoint info - map method_web_endpoint_info = 74; + // Need a mapping of method names to method definitions + map method_definitions = 74; } message FunctionBindParamsRequest { @@ -1177,7 +1178,7 @@ message FunctionData { string module_name = 1; string function_name = 2; - Function.FunctionType function_type = 3; + FunctionType function_type = 3; // Scheduling related fields. uint32 warm_pool_size = 4; @@ -1333,13 +1334,14 @@ message FunctionHandleMetadata { // Should be a subset and use IDs/types from `Function` above string function_name = 2; - Function.FunctionType function_type = 8; + FunctionType function_type = 8; string web_url = 28; bool is_method = 39; string use_function_id = 40; // used for methods string use_method_name = 41; // used for methods string definition_id = 42; ClassParameterInfo class_parameter_info = 43; + map method_handle_metadata = 44; } message FunctionInput { @@ -1386,10 +1388,11 @@ message FunctionPrecreateRequest { string app_id = 1; string function_name = 2 [ (modal.options.audit_target_attr) = true ]; string existing_function_id = 3; - Function.FunctionType function_type = 4; + FunctionType function_type = 4; WebhookConfig webhook_config = 5; string use_function_id = 6; // for class methods - use this function id instead for invocations - the *referenced* function should have is_class=True string use_method_name = 7; // for class methods - this method name needs to be included in the FunctionInput + map method_webhook_configs = 8; } message FunctionPrecreateResponse { @@ -2387,10 +2390,11 @@ message VolumeRemoveFileRequest { bool recursive = 3; } -message WebEndpointInfo { - WebhookConfig webhook_config = 1; - string web_url = 2; - WebUrlInfo web_url_info = 3; +message MethodDefinition { + FunctionType function_type = 1; + WebhookConfig webhook_config = 2; + string web_url = 3; + WebUrlInfo web_url_info = 4; } message WebUrlInfo { From bd90f9f04654b7bdee9a2ef166dd5e67d4c7911c Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 19:25:01 +0000 Subject: [PATCH 09/96] wip --- modal/cls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/cls.py b/modal/cls.py index 1d4354952..e4599ff4d 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -514,7 +514,7 @@ def __call__(self, *args, **kwargs) -> _Obj: return _Obj( self._user_cls, self._class_service_function, - self._method_functions, + self._class_service_function._method_functions, self._from_other_workspace, self._options, args, From b3aa81dde26dd877c3541cb8fafebd739673f9f4 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 19:41:10 +0000 Subject: [PATCH 10/96] fix bugs --- modal/_container_entrypoint.py | 2 +- modal/app.py | 4 ++-- modal/functions.py | 20 +++++++++++--------- test/container_test.py | 18 +++++++++--------- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 2de24d702..b3230fd48 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -137,7 +137,7 @@ def get_finalized_functions( # Check this property before we turn it into a method (overriden by webhooks) is_async = get_is_async(self._user_defined_callable) # Use the function definition for whether this is a generator (overriden by webhooks) - is_generator = fun_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR + is_generator = fun_def.function_type == api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR webhook_config = fun_def.webhook_config if not webhook_config.type: diff --git a/modal/app.py b/modal/app.py index a5c7bf438..edd8f73b9 100644 --- a/modal/app.py +++ b/modal/app.py @@ -925,9 +925,9 @@ def wrapper(user_cls: CLS_T) -> CLS_T: ) for method_name, partial_function in partial_functions.items(): if partial_function.is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR + function_type = api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION method_definition = api_pb2.MethodDefinition( webhook_config=partial_function.webhook_config, function_type=function_type ) diff --git a/modal/functions.py b/modal/functions.py index 9ac3d2f6f..b47f6a943 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -400,9 +400,11 @@ def _bind_method( # def _deps(): # return [class_service_function] - rep = f"Method({full_name})" + async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): + pass - fun = _Function(rep) + rep = f"Method({full_name})" + fun = _Function._from_loader(_load, rep) # fun = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) fun._tag = full_name fun._raw_f = partial_function.raw_f @@ -715,9 +717,9 @@ def _deps(only_explicit_mounts=False) -> List[_Object]: async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub if is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR + function_type = api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION assert resolver.app_id req = api_pb2.FunctionPrecreateRequest( @@ -745,9 +747,9 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona assert resolver.client and resolver.client.stub with FunctionCreationStatus(resolver, tag) as function_creation_status: if is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR + function_type = api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION timeout_secs = timeout @@ -1184,7 +1186,7 @@ def _initialize_from_empty(self): def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata) - self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR + self._is_generator = metadata.function_type == api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR self._web_url = metadata.web_url self._function_name = metadata.function_name self._is_method = metadata.is_method @@ -1202,9 +1204,9 @@ def _get_metadata(self): return api_pb2.FunctionHandleMetadata( function_name=self._function_name, function_type=( - api_pb2.Function.FUNCTION_TYPE_GENERATOR + api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR if self._is_generator - else api_pb2.Function.FUNCTION_TYPE_FUNCTION + else api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION ), web_url=self._web_url or "", use_method_name=self._use_method_name, diff --git a/test/container_test.py b/test/container_test.py index 4d207257a..78ab539f1 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -167,7 +167,7 @@ def _get_multi_inputs_with_methods(args: List[Tuple[str, Tuple, Dict]] = []) -> def _container_args( module_name, function_name, - function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION, webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED, definition_type=api_pb2.Function.DEFINITION_TYPE_FILE, app_name: str = "", @@ -239,7 +239,7 @@ def _run_container( function_name, fail_get_inputs=False, inputs=None, - function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION, webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED, definition_type=api_pb2.Function.DEFINITION_TYPE_FILE, app_name: str = "", @@ -420,7 +420,7 @@ def test_generator_success(servicer, event_loop): servicer, "test.supports.functions", "gen_n", - function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, ) items, exc = _unwrap_generator(ret) @@ -435,7 +435,7 @@ def test_generator_failure(servicer, capsys): servicer, "test.supports.functions", "gen_n_fail_on_m", - function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, inputs=inputs, ) items, exc = _unwrap_generator(ret) @@ -643,7 +643,7 @@ def test_function_returning_generator(servicer): servicer, "test.supports.functions", "fun_returning_gen", - function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, ) items, exc = _unwrap_generator(ret) assert len(items) == 42 @@ -851,7 +851,7 @@ def test_webhook_streaming_sync(servicer): "webhook_streaming", inputs=inputs, webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION, - function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, ) data = _unwrap_asgi(ret) bodies = [d["body"].decode() for d in data if d.get("body")] @@ -868,7 +868,7 @@ def test_webhook_streaming_async(servicer): "webhook_streaming_async", inputs=inputs, webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION, - function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, ) data = _unwrap_asgi(ret) @@ -1022,7 +1022,7 @@ def test_cls_generator(servicer): servicer, "test.supports.functions", "Cls.*", - function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, is_class=True, inputs=_get_inputs(method_name="generator"), ) @@ -1077,7 +1077,7 @@ def test_cli(servicer, credentials): function_def = api_pb2.Function( module_name="test.supports.functions", function_name="square", - function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION, + function_type=api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION, definition_type=api_pb2.Function.DEFINITION_TYPE_FILE, object_dependencies=[api_pb2.ObjectDependency(object_id="im-123")], ) From 033dff727c53d41f7f5a5c842c613d33d331a9e4 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 20:04:00 +0000 Subject: [PATCH 11/96] bug fix --- modal/functions.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index b47f6a943..5af0b01f9 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -729,15 +729,13 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti existing_function_id=existing_object_id or "", ) if method_definitions: - method_webhook_configs = { - method_name: method_def.webhook_config for method_name, method_def in method_definitions.items() - } - req.method_webhook_configs.update(method_webhook_configs) + for method_name, method_def in method_definitions.items(): + req.method_webhook_configs[method_name].CopyFrom(method_def.webhook_config) else: req.webhook_config = webhook_config response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) - for method_name, method_function in self._method_functions: + for method_name, method_function in self._method_functions.items(): method_handle_metadata = response.handle_metadata.method_handle_metadata[method_name] method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) @@ -937,7 +935,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona function_creation_status.set_response(response) self._hydrate(response.function_id, resolver.client, response.handle_metadata) - for method_name, method_function in self._method_functions: + for method_name, method_function in self._method_functions.items(): method_handle_metadata = response.handle_metadata.method_handle_metadata[method_name] method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) From 283c958c44dd1fdaed197246ca44511b1a37e048 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 20:23:03 +0000 Subject: [PATCH 12/96] FunctionType fixes --- modal/_container_entrypoint.py | 2 +- modal/app.py | 8 ++++---- modal/functions.py | 14 +++++++------- modal_proto/api.proto | 20 ++++++++++---------- test/container_test.py | 18 +++++++++--------- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index b3230fd48..2de24d702 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -137,7 +137,7 @@ def get_finalized_functions( # Check this property before we turn it into a method (overriden by webhooks) is_async = get_is_async(self._user_defined_callable) # Use the function definition for whether this is a generator (overriden by webhooks) - is_generator = fun_def.function_type == api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR + is_generator = fun_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR webhook_config = fun_def.webhook_config if not webhook_config.type: diff --git a/modal/app.py b/modal/app.py index edd8f73b9..71981764e 100644 --- a/modal/app.py +++ b/modal/app.py @@ -919,15 +919,15 @@ def wrapper(user_cls: CLS_T) -> CLS_T: ): raise InvalidError("A class must have `enable_memory_snapshot=True` to use `snap=True` on its methods.") - method_definitions = {} + method_definitions: Dict[str, api_pb2.MethodDefinition] = {} partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( user_cls, _PartialFunctionFlags.FUNCTION ) for method_name, partial_function in partial_functions.items(): if partial_function.is_generator: - function_type = api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR + function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR else: - function_type = api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION + function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION method_definition = api_pb2.MethodDefinition( webhook_config=partial_function.webhook_config, function_type=function_type ) @@ -968,7 +968,7 @@ def wrapper(user_cls: CLS_T) -> CLS_T: _experimental_proxy_ip=_experimental_proxy_ip, ) - cls_func._method_functions = {} + cls_func._method_functions: Dict[str, _Function] = {} for method_name, partial_function in partial_functions.items(): method_function = cls_func._bind_method(user_cls, method_name, partial_function) cls_func._method_functions[method_name] = method_function diff --git a/modal/functions.py b/modal/functions.py index 5af0b01f9..38a4c2dd4 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -717,9 +717,9 @@ def _deps(only_explicit_mounts=False) -> List[_Object]: async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub if is_generator: - function_type = api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR + function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR else: - function_type = api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION + function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION assert resolver.app_id req = api_pb2.FunctionPrecreateRequest( @@ -745,9 +745,9 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona assert resolver.client and resolver.client.stub with FunctionCreationStatus(resolver, tag) as function_creation_status: if is_generator: - function_type = api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR + function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR else: - function_type = api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION + function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION timeout_secs = timeout @@ -1184,7 +1184,7 @@ def _initialize_from_empty(self): def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata) - self._is_generator = metadata.function_type == api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR + self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR self._web_url = metadata.web_url self._function_name = metadata.function_name self._is_method = metadata.is_method @@ -1202,9 +1202,9 @@ def _get_metadata(self): return api_pb2.FunctionHandleMetadata( function_name=self._function_name, function_type=( - api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR + api_pb2.Function.FUNCTION_TYPE_GENERATOR if self._is_generator - else api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION + else api_pb2.Function.FUNCTION_TYPE_FUNCTION ), web_url=self._web_url or "", use_method_name=self._use_method_name, diff --git a/modal_proto/api.proto b/modal_proto/api.proto index d1bdc3d87..3290afe9f 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -953,12 +953,6 @@ message FileEntry { uint64 size = 4; } -enum FunctionType { - FUNCTION_TYPE_UNSPECIFIED = 0; - FUNCTION_TYPE_GENERATOR = 1; - FUNCTION_TYPE_FUNCTION = 2; -} - message Function { string module_name = 1; string function_name = 2; @@ -973,6 +967,12 @@ message Function { } DefinitionType definition_type = 7; + enum FunctionType { + FUNCTION_TYPE_UNSPECIFIED = 0; + FUNCTION_TYPE_GENERATOR = 1; + FUNCTION_TYPE_FUNCTION = 2; + } + FunctionType function_type = 8; Resources resources = 9; @@ -1178,7 +1178,7 @@ message FunctionData { string module_name = 1; string function_name = 2; - FunctionType function_type = 3; + Function.FunctionType function_type = 3; // Scheduling related fields. uint32 warm_pool_size = 4; @@ -1334,7 +1334,7 @@ message FunctionHandleMetadata { // Should be a subset and use IDs/types from `Function` above string function_name = 2; - FunctionType function_type = 8; + Function.FunctionType function_type = 8; string web_url = 28; bool is_method = 39; string use_function_id = 40; // used for methods @@ -1388,7 +1388,7 @@ message FunctionPrecreateRequest { string app_id = 1; string function_name = 2 [ (modal.options.audit_target_attr) = true ]; string existing_function_id = 3; - FunctionType function_type = 4; + Function.FunctionType function_type = 4; WebhookConfig webhook_config = 5; string use_function_id = 6; // for class methods - use this function id instead for invocations - the *referenced* function should have is_class=True string use_method_name = 7; // for class methods - this method name needs to be included in the FunctionInput @@ -2391,7 +2391,7 @@ message VolumeRemoveFileRequest { } message MethodDefinition { - FunctionType function_type = 1; + Function.FunctionType function_type = 1; WebhookConfig webhook_config = 2; string web_url = 3; WebUrlInfo web_url_info = 4; diff --git a/test/container_test.py b/test/container_test.py index 78ab539f1..4d207257a 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -167,7 +167,7 @@ def _get_multi_inputs_with_methods(args: List[Tuple[str, Tuple, Dict]] = []) -> def _container_args( module_name, function_name, - function_type=api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION, + function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION, webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED, definition_type=api_pb2.Function.DEFINITION_TYPE_FILE, app_name: str = "", @@ -239,7 +239,7 @@ def _run_container( function_name, fail_get_inputs=False, inputs=None, - function_type=api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION, + function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION, webhook_type=api_pb2.WEBHOOK_TYPE_UNSPECIFIED, definition_type=api_pb2.Function.DEFINITION_TYPE_FILE, app_name: str = "", @@ -420,7 +420,7 @@ def test_generator_success(servicer, event_loop): servicer, "test.supports.functions", "gen_n", - function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, ) items, exc = _unwrap_generator(ret) @@ -435,7 +435,7 @@ def test_generator_failure(servicer, capsys): servicer, "test.supports.functions", "gen_n_fail_on_m", - function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, inputs=inputs, ) items, exc = _unwrap_generator(ret) @@ -643,7 +643,7 @@ def test_function_returning_generator(servicer): servicer, "test.supports.functions", "fun_returning_gen", - function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, ) items, exc = _unwrap_generator(ret) assert len(items) == 42 @@ -851,7 +851,7 @@ def test_webhook_streaming_sync(servicer): "webhook_streaming", inputs=inputs, webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION, - function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, ) data = _unwrap_asgi(ret) bodies = [d["body"].decode() for d in data if d.get("body")] @@ -868,7 +868,7 @@ def test_webhook_streaming_async(servicer): "webhook_streaming_async", inputs=inputs, webhook_type=api_pb2.WEBHOOK_TYPE_FUNCTION, - function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, ) data = _unwrap_asgi(ret) @@ -1022,7 +1022,7 @@ def test_cls_generator(servicer): servicer, "test.supports.functions", "Cls.*", - function_type=api_pb2.FunctionType.FUNCTION_TYPE_GENERATOR, + function_type=api_pb2.Function.FUNCTION_TYPE_GENERATOR, is_class=True, inputs=_get_inputs(method_name="generator"), ) @@ -1077,7 +1077,7 @@ def test_cli(servicer, credentials): function_def = api_pb2.Function( module_name="test.supports.functions", function_name="square", - function_type=api_pb2.FunctionType.FUNCTION_TYPE_FUNCTION, + function_type=api_pb2.Function.FUNCTION_TYPE_FUNCTION, definition_type=api_pb2.Function.DEFINITION_TYPE_FILE, object_dependencies=[api_pb2.ObjectDependency(object_id="im-123")], ) From f0de58e6c9a51c5d21eff711e904541ee59c08a2 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 20:54:56 +0000 Subject: [PATCH 13/96] code cleanup --- modal/_utils/function_utils.py | 4 ++++ modal/app.py | 20 ------------------- modal/functions.py | 36 ++++++++++++++++++++++++++-------- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 329645872..c0d2921b5 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -93,6 +93,10 @@ def is_async(function): raise RuntimeError(f"Function {function} is a strange type {type(function)}") +def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType: + return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION + + class FunctionInfo: """Class that helps us extract a bunch of information about a function.""" diff --git a/modal/app.py b/modal/app.py index 71981764e..d4af4de32 100644 --- a/modal/app.py +++ b/modal/app.py @@ -919,20 +919,6 @@ def wrapper(user_cls: CLS_T) -> CLS_T: ): raise InvalidError("A class must have `enable_memory_snapshot=True` to use `snap=True` on its methods.") - method_definitions: Dict[str, api_pb2.MethodDefinition] = {} - partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( - user_cls, _PartialFunctionFlags.FUNCTION - ) - for method_name, partial_function in partial_functions.items(): - if partial_function.is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION - method_definition = api_pb2.MethodDefinition( - webhook_config=partial_function.webhook_config, function_type=function_type - ) - method_definitions[method_name] = method_definition - info = FunctionInfo(None, serialized=serialized, user_cls=user_cls) cls_func = _Function.from_args( @@ -958,7 +944,6 @@ def wrapper(user_cls: CLS_T) -> CLS_T: cpu=cpu, keep_warm=keep_warm, cloud=cloud, - method_definitions=method_definitions, enable_memory_snapshot=enable_memory_snapshot, checkpointing_enabled=checkpointing_enabled, block_network=block_network, @@ -968,11 +953,6 @@ def wrapper(user_cls: CLS_T) -> CLS_T: _experimental_proxy_ip=_experimental_proxy_ip, ) - cls_func._method_functions: Dict[str, _Function] = {} - for method_name, partial_function in partial_functions.items(): - method_function = cls_func._bind_method(user_cls, method_name, partial_function) - cls_func._method_functions[method_name] = method_function - self._add_function(cls_func, is_web_endpoint=False) cls: _Cls = _Cls.from_local(user_cls, self, cls_func) diff --git a/modal/functions.py b/modal/functions.py index 38a4c2dd4..73db41ae5 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -51,6 +51,7 @@ _create_input, _process_result, _stream_function_call_data, + get_function_type, is_async, ) from ._utils.grpc_utils import retry_transient_errors @@ -498,7 +499,6 @@ def from_args( allow_cross_region_volumes: bool = False, volumes: Dict[Union[str, PurePosixPath], Union[_Volume, _CloudBucketMount]] = {}, webhook_config: Optional[api_pb2.WebhookConfig] = None, - method_definitions: Optional[Dict[str, api_pb2.MethodDefinition]] = None, memory: Optional[Union[int, Tuple[int, int]]] = None, proxy: Optional[_Proxy] = None, retries: Optional[Union[int, Retries]] = None, @@ -525,6 +525,9 @@ def from_args( _experimental_proxy_ip: Optional[str] = None, ) -> None: """mdmd:hidden""" + # Needed to avoid circular imports + from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags + tag = info.get_tag() if info.raw_f: @@ -537,7 +540,6 @@ def from_args( else: # must be a "class service function" assert info.user_cls - assert method_definitions assert not webhook_config assert not schedule @@ -601,8 +603,8 @@ def from_args( ) if info.user_cls and not is_auto_snapshot: - # Needed to avoid circular imports - from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags + # # Needed to avoid circular imports + # from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items() for k, pf in build_functions: @@ -687,6 +689,18 @@ def from_args( if image is not None and not isinstance(image, _Image): raise InvalidError(f"Expected modal.Image object. Got {type(image)}.") + method_definitions: Optional[Dict[str, api_pb2.MethodDefinition]] = None + if info.user_cls: + partial_functions: Dict[ + str, "modal.partial_function._PartialFunction" + ] = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION) + for method_name, partial_function in partial_functions.items(): + function_type = get_function_type(partial_function.is_generator) + method_definition = api_pb2.MethodDefinition( + webhook_config=partial_function.webhook_config, function_type=function_type + ) + method_definitions[method_name] = method_definition + def _deps(only_explicit_mounts=False) -> List[_Object]: deps: List[_Object] = list(secrets) if only_explicit_mounts: @@ -716,10 +730,7 @@ def _deps(only_explicit_mounts=False) -> List[_Object]: async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub - if is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = get_function_type(is_generator) assert resolver.app_id req = api_pb2.FunctionPrecreateRequest( @@ -952,6 +963,15 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._is_method = False obj._spec = function_spec # needed for modal shell + if info.user_cls: + partial_functions: Dict[ + str, "modal.partial_function._PartialFunction" + ] = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION) + obj._method_functions = {} + for method_name, partial_function in partial_functions.items(): + method_function = obj._bind_method(info.user_cls, method_name, partial_function) + obj._method_functions[method_name] = method_function + # Used to check whether we should rebuild a modal.Image which uses `run_function`. gpus: List[GPU_T] = gpu if isinstance(gpu, list) else [gpu] obj._build_args = dict( # See get_build_def From 08afd9a9ba35ca3e0b822db01e1b35ebb48ae353 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 27 Oct 2024 23:52:11 +0000 Subject: [PATCH 14/96] protoc changes --- modal/cls.py | 18 +++++++++--------- modal/functions.py | 41 +++++++++++++++++++++++++++-------------- modal_proto/api.proto | 11 ++++++----- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index e4599ff4d..4d2a5d428 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -312,15 +312,15 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio _Cls.validate_construction_mechanism(user_cls) # functions: Dict[str, _Function] = {} - # partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( - # user_cls, _PartialFunctionFlags.FUNCTION - # ) - - # for method_name, partial_function in partial_functions.items(): - # method_function = class_service_function._bind_method(user_cls, method_name, partial_function) - # # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) - # partial_function.wrapped = True - # functions[method_name] = method_function + partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( + user_cls, _PartialFunctionFlags.FUNCTION + ) + + for partial_function in partial_functions.values(): + # method_function = class_service_function._bind_method(user_cls, method_name, partial_function) + # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) + partial_function.wrapped = True + # functions[method_name] = method_function # Disable the warning that these are not wrapped for partial_function in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags.FUNCTION).values(): diff --git a/modal/functions.py b/modal/functions.py index 73db41ae5..dd221f1c8 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -318,8 +318,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type # when this is the method of a class/object function, invocation of this function # should be using another function id and supply the method name in the FunctionInput: - _use_function_id: str # The function to invoke - _use_method_name: str = "" + # _use_function_id: str # The function to invoke + # _use_method_name: str = "" # TODO (elias): remove _parent. In case of instance functions, and methods bound on those, # this references the parent class-function and is used to infer the client for lazy-loaded methods @@ -691,13 +691,17 @@ def from_args( method_definitions: Optional[Dict[str, api_pb2.MethodDefinition]] = None if info.user_cls: + method_definitions = {} partial_functions: Dict[ str, "modal.partial_function._PartialFunction" ] = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION) for method_name, partial_function in partial_functions.items(): function_type = get_function_type(partial_function.is_generator) + function_name = f"{info.user_cls.__name__}.{method_name}" method_definition = api_pb2.MethodDefinition( - webhook_config=partial_function.webhook_config, function_type=function_type + webhook_config=partial_function.webhook_config, + function_type=function_type, + function_name=function_name, ) method_definitions[method_name] = method_definition @@ -740,8 +744,9 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti existing_function_id=existing_object_id or "", ) if method_definitions: - for method_name, method_def in method_definitions.items(): - req.method_webhook_configs[method_name].CopyFrom(method_def.webhook_config) + req.method_definitions = method_definitions + # for method_name, method_def in method_definitions.items(): + # req.method_webhook_configs[method_name].CopyFrom(method_def.webhook_config) else: req.webhook_config = webhook_config response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) @@ -955,6 +960,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._raw_f = info.raw_f obj._info = info + obj._function_name = info.function_name obj._tag = tag obj._all_mounts = all_mounts # needed for modal.serve file watching obj._app = app # needed for CLI right now @@ -1208,8 +1214,8 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._web_url = metadata.web_url self._function_name = metadata.function_name self._is_method = metadata.is_method - self._use_function_id = metadata.use_function_id - self._use_method_name = metadata.use_method_name + # self._use_function_id = metadata.use_function_id + # self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info self._definition_id = metadata.definition_id @@ -1219,19 +1225,26 @@ def _invocation_function_id(self) -> str: def _get_metadata(self): # Overridden concrete implementation of base class method assert self._function_name, f"Function name must be set before metadata can be retrieved for {self}" + return api_pb2.FunctionHandleMetadata( function_name=self._function_name, - function_type=( - api_pb2.Function.FUNCTION_TYPE_GENERATOR - if self._is_generator - else api_pb2.Function.FUNCTION_TYPE_FUNCTION - ), + function_type=get_function_type(self._is_generator), web_url=self._web_url or "", - use_method_name=self._use_method_name, - use_function_id=self._use_function_id, + # use_method_name=self._use_method_name, + # use_function_id=self._use_function_id, is_method=self._is_method, class_parameter_info=self._class_parameter_info, definition_id=self._definition_id, + method_handle_metadata=[ + api_pb2.FunctionHandleMetadata( + function_name=method_function._function_name, + function_type=get_function_type(method_function._is_generator), + web_url=method_function._web_url or "", + is_method=method_function._is_method, + definition_id=method_function._definition_id, + ) + for method_function in self._method_functions.values() + ], ) def _check_no_web_url(self, fn_name: str): diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 3290afe9f..1d52ae6f0 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -1392,7 +1392,7 @@ message FunctionPrecreateRequest { WebhookConfig webhook_config = 5; string use_function_id = 6; // for class methods - use this function id instead for invocations - the *referenced* function should have is_class=True string use_method_name = 7; // for class methods - this method name needs to be included in the FunctionInput - map method_webhook_configs = 8; + map method_definitions = 8; } message FunctionPrecreateResponse { @@ -2391,10 +2391,11 @@ message VolumeRemoveFileRequest { } message MethodDefinition { - Function.FunctionType function_type = 1; - WebhookConfig webhook_config = 2; - string web_url = 3; - WebUrlInfo web_url_info = 4; + string function_name = 1; + Function.FunctionType function_type = 2; + WebhookConfig webhook_config = 3; + string web_url = 4; + WebUrlInfo web_url_info = 5; } message WebUrlInfo { From f95b8a3adb6de4b4856d617b6760299a04b38d5b Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 28 Oct 2024 01:16:24 +0000 Subject: [PATCH 15/96] fixes --- modal/cls.py | 2 +- modal/functions.py | 31 ++++++++++++++++++------------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index 4d2a5d428..5d018b1e0 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -328,7 +328,7 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio # Get all callables callables: Dict[str, Callable] = { - k: pf.raw_f for k, pf in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags(0)).items() + k: pf.raw_f for k, pf in _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.all()).items() } def _deps() -> List[_Function]: diff --git a/modal/functions.py b/modal/functions.py index dd221f1c8..04d38068a 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -319,7 +319,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type # when this is the method of a class/object function, invocation of this function # should be using another function id and supply the method name in the FunctionInput: # _use_function_id: str # The function to invoke - # _use_method_name: str = "" + _use_method_name: str = "" # TODO (elias): remove _parent. In case of instance functions, and methods bound on those, # this references the parent class-function and is used to infer the client for lazy-loaded methods @@ -745,15 +745,10 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti ) if method_definitions: req.method_definitions = method_definitions - # for method_name, method_def in method_definitions.items(): - # req.method_webhook_configs[method_name].CopyFrom(method_def.webhook_config) else: req.webhook_config = webhook_config response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) - for method_name, method_function in self._method_functions.items(): - method_handle_metadata = response.handle_metadata.method_handle_metadata[method_name] - method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): from ._output import FunctionCreationStatus # Deferred import to avoid Rich dependency in container @@ -951,9 +946,6 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona function_creation_status.set_response(response) self._hydrate(response.function_id, resolver.client, response.handle_metadata) - for method_name, method_function in self._method_functions.items(): - method_handle_metadata = response.handle_metadata.method_handle_metadata[method_name] - method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) @@ -1205,7 +1197,8 @@ def _initialize_from_empty(self): self._function_name = None self._info = None self._all_mounts = [] # used for file watching - self._use_function_id = "" + # self._use_function_id = "" + self._use_method_name = "" def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method @@ -1215,12 +1208,23 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._function_name = metadata.function_name self._is_method = metadata.is_method # self._use_function_id = metadata.use_function_id - # self._use_method_name = metadata.use_method_name + self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info self._definition_id = metadata.definition_id + for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): + method_function = self._method_functions[method_name] + method_function._is_generator = ( + method_handle_metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR + ) + method_function._web_url = method_handle_metadata.web_url + method_function._function_name = method_handle_metadata.function_name + method_function._is_method = method_handle_metadata.is_method + method_function._use_method_name = method_handle_metadata.use_method_name + method_function._definition_id = method_handle_metadata.definition_id def _invocation_function_id(self) -> str: - return self._use_function_id or self.object_id + # return self._use_function_id or self.object_id + return self.object_id def _get_metadata(self): # Overridden concrete implementation of base class method @@ -1230,7 +1234,7 @@ def _get_metadata(self): function_name=self._function_name, function_type=get_function_type(self._is_generator), web_url=self._web_url or "", - # use_method_name=self._use_method_name, + use_method_name=self._use_method_name, # use_function_id=self._use_function_id, is_method=self._is_method, class_parameter_info=self._class_parameter_info, @@ -1242,6 +1246,7 @@ def _get_metadata(self): web_url=method_function._web_url or "", is_method=method_function._is_method, definition_id=method_function._definition_id, + use_method_name=method_function._use_method_name, ) for method_function in self._method_functions.values() ], From 916eb9feeb9c4de198ed441ea278760450e18871 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 28 Oct 2024 02:16:43 +0000 Subject: [PATCH 16/96] fixes --- modal/functions.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 04d38068a..3306a6d50 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -744,7 +744,8 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti existing_function_id=existing_object_id or "", ) if method_definitions: - req.method_definitions = method_definitions + for method_name, method_definition in method_definitions.items(): + req.method_definitions[method_name].CopyFrom(method_definition) else: req.webhook_config = webhook_config response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) @@ -1220,7 +1221,7 @@ def _hydrate_metadata(self, metadata: Optional[Message]): method_function._function_name = method_handle_metadata.function_name method_function._is_method = method_handle_metadata.is_method method_function._use_method_name = method_handle_metadata.use_method_name - method_function._definition_id = method_handle_metadata.definition_id + # method_function._definition_id = method_handle_metadata.definition_id def _invocation_function_id(self) -> str: # return self._use_function_id or self.object_id @@ -1239,17 +1240,17 @@ def _get_metadata(self): is_method=self._is_method, class_parameter_info=self._class_parameter_info, definition_id=self._definition_id, - method_handle_metadata=[ - api_pb2.FunctionHandleMetadata( + method_handle_metadata={ + method_name: api_pb2.FunctionHandleMetadata( function_name=method_function._function_name, function_type=get_function_type(method_function._is_generator), web_url=method_function._web_url or "", is_method=method_function._is_method, - definition_id=method_function._definition_id, + # definition_id=method_function._definition_id, use_method_name=method_function._use_method_name, ) - for method_function in self._method_functions.values() - ], + for method_name, method_function in self._method_functions.items() + }, ) def _check_no_web_url(self, fn_name: str): From 59b7668b91441c651874889dd2ac13362a673a0f Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 28 Oct 2024 18:32:03 +0000 Subject: [PATCH 17/96] functions cleanup --- modal/functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 3306a6d50..c8e392093 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -705,6 +705,8 @@ def from_args( ) method_definitions[method_name] = method_definition + function_type = get_function_type(is_generator) + def _deps(only_explicit_mounts=False) -> List[_Object]: deps: List[_Object] = list(secrets) if only_explicit_mounts: @@ -734,7 +736,6 @@ def _deps(only_explicit_mounts=False) -> List[_Object]: async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub - function_type = get_function_type(is_generator) assert resolver.app_id req = api_pb2.FunctionPrecreateRequest( @@ -756,11 +757,6 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona assert resolver.client and resolver.client.stub with FunctionCreationStatus(resolver, tag) as function_creation_status: - if is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION - timeout_secs = timeout if app and app.is_interactive and not is_builder_function: From 2c455271f0116842e99854a1681b1ab7549fff22 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 28 Oct 2024 22:30:46 +0000 Subject: [PATCH 18/96] CopyFrom fix --- modal/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index c8e392093..004568dc1 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -748,7 +748,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti for method_name, method_definition in method_definitions.items(): req.method_definitions[method_name].CopyFrom(method_definition) else: - req.webhook_config = webhook_config + req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) @@ -1196,6 +1196,7 @@ def _initialize_from_empty(self): self._all_mounts = [] # used for file watching # self._use_function_id = "" self._use_method_name = "" + self._method_functions = {} def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method From 0e5650178df77a581a993136121fe541b2f90026 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 28 Oct 2024 23:34:17 +0000 Subject: [PATCH 19/96] fixes from testing --- modal/functions.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 004568dc1..84b5cb84b 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -751,6 +751,10 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) + for method_name, method_function in self._method_functions.items(): + method_function._hydrate( + response.function_id, resolver.client, response.handle_metadata.method_handle_metadata[method_name] + ) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): from ._output import FunctionCreationStatus # Deferred import to avoid Rich dependency in container @@ -943,6 +947,10 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona function_creation_status.set_response(response) self._hydrate(response.function_id, resolver.client, response.handle_metadata) + for method_name, method_function in self._method_functions.items(): + method_function._hydrate( + response.function_id, resolver.client, response.handle_metadata.method_handle_metadata[method_name] + ) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) @@ -1209,16 +1217,16 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info self._definition_id = metadata.definition_id - for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): - method_function = self._method_functions[method_name] - method_function._is_generator = ( - method_handle_metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR - ) - method_function._web_url = method_handle_metadata.web_url - method_function._function_name = method_handle_metadata.function_name - method_function._is_method = method_handle_metadata.is_method - method_function._use_method_name = method_handle_metadata.use_method_name - # method_function._definition_id = method_handle_metadata.definition_id + # for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): + # method_function = self._method_functions[method_name] + # method_function._is_generator = ( + # method_handle_metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR + # ) + # method_function._web_url = method_handle_metadata.web_url + # method_function._function_name = method_handle_metadata.function_name + # method_function._is_method = method_handle_metadata.is_method + # method_function._use_method_name = method_handle_metadata.use_method_name + # method_function._definition_id = method_handle_metadata.definition_id def _invocation_function_id(self) -> str: # return self._use_function_id or self.object_id From 80e751dca1b9bfaa2d1037917f5f6112ec6e89d7 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 29 Oct 2024 18:13:50 +0000 Subject: [PATCH 20/96] Fix Cls.lookup --- modal/cls.py | 12 ++++++------ modal/functions.py | 34 +++++++++++++++++++++++++++++----- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index be62bea70..d57fc39b9 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -237,7 +237,7 @@ class _Cls(_Object, type_prefix="cs"): _class_service_function: Optional[ _Function ] # The _Function serving *all* methods of the class, used for version >=v0.63 - _method_functions: Dict[str, _Function] # Placeholder _Functions for each method + # _method_functions: Dict[str, _Function] # Placeholder _Functions for each method _options: Optional[api_pb2.FunctionOptions] _callables: Dict[str, Callable[..., Any]] _from_other_workspace: Optional[bool] # Functions require FunctionBindParams before invocation. @@ -246,7 +246,7 @@ class _Cls(_Object, type_prefix="cs"): def _initialize_from_empty(self): self._user_cls = None self._class_service_function = None - self._method_functions = {} + # self._method_functions = {} self._options = None self._callables = {} self._from_other_workspace = None @@ -254,7 +254,7 @@ def _initialize_from_empty(self): def _initialize_from_other(self, other: "_Cls"): self._user_cls = other._user_cls self._class_service_function = other._class_service_function - self._method_functions = other._method_functions + # self._method_functions = other._method_functions self._options = other._options self._callables = other._callables self._from_other_workspace = other._from_other_workspace @@ -384,7 +384,7 @@ def _uses_common_service_function(self): def from_name( cls: Type["_Cls"], app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, environment_name: Optional[str] = None, workspace: Optional[str] = None, @@ -432,7 +432,7 @@ async def _load_remote(obj: _Object, resolver: Resolver, existing_object_id: Opt obj._hydrate(response.class_id, resolver.client, response.handle_metadata) - rep = f"Ref({app_name})" + rep = f"Ref({tag})" cls = cls._from_loader(_load_remote, rep, is_another_app=True) cls._from_other_workspace = bool(workspace is not None) return cls @@ -499,7 +499,7 @@ def with_options( @staticmethod async def lookup( app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, client: Optional[_Client] = None, environment_name: Optional[str] = None, diff --git a/modal/functions.py b/modal/functions.py index 84b5cb84b..4bbaf763a 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1100,7 +1100,7 @@ async def keep_warm(self, warm_pool_size: int) -> None: def from_name( cls: Type["_Function"], app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, environment_name: Optional[str] = None, ) -> "_Function": @@ -1115,7 +1115,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: assert resolver.client and resolver.client.stub request = api_pb2.FunctionGetRequest( app_name=app_name, - object_tag=tag or "", + object_tag=tag, namespace=namespace, environment_name=_get_environment_name(environment_name, resolver) or "", ) @@ -1128,14 +1128,38 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: raise self._hydrate(response.function_id, resolver.client, response.handle_metadata) - - rep = f"Ref({app_name})" + for method_name, method_handle_metadata in response.handle_metadata.method_handle_metadata.items(): + # Construct the method function + async def _load( + method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str] + ): + pass + + class_name = tag[:-2] # remove the .* suffix from the class service function tag to get the class name + full_name = f"{class_name}.{method_name}" + rep = f"Method({full_name})" + method_function = _Function._from_loader(_load, rep) + method_function._tag = full_name + # fun._raw_f = partial_function.raw_f + # fun._info = FunctionInfo( + # partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized() + # ) # needed for .local() + method_function._use_method_name = method_name + method_function._app = self._app + # fun._is_generator = partial_function.is_generator + method_function._all_mounts = self._all_mounts + method_function._spec = self._spec + method_function._is_method = True + method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) + self._method_functions[method_name] = method_function + + rep = f"Ref({tag})" return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True) @staticmethod async def lookup( app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, client: Optional[_Client] = None, environment_name: Optional[str] = None, From 0d2e96e0c3be3246dad164fbe5a545a9bcefb880 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 29 Oct 2024 21:05:09 +0000 Subject: [PATCH 21/96] proto changes for class_get_handle_metadata to handle old clients --- modal_proto/api.proto | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 0f55c950a..cd0025e2d 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -346,6 +346,7 @@ message AppGetObjectsItem { message AppGetObjectsRequest { string app_id = 1; bool include_unindexed = 2; + bool only_class_function = 3; } message AppGetObjectsResponse { @@ -600,6 +601,7 @@ message ClassCreateRequest { string existing_class_id = 2; repeated ClassMethod methods = 3; reserved 4; // removed class_function_id + bool only_class_function = 5; } message ClassCreateResponse { @@ -615,6 +617,7 @@ message ClassGetRequest { bool lookup_published = 8; // Lookup class on app published by another workspace string workspace_name = 9; + bool only_class_function = 10; } message ClassGetResponse { From 3d6f1207eb77e551d22b0e9620accd552733b56f Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 29 Oct 2024 22:30:57 +0000 Subject: [PATCH 22/96] remove extraneous space --- modal_proto/api.proto | 1 - 1 file changed, 1 deletion(-) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index cd0025e2d..862dae931 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -1088,7 +1088,6 @@ message Function { FUNCTION_TYPE_GENERATOR = 1; FUNCTION_TYPE_FUNCTION = 2; } - FunctionType function_type = 8; Resources resources = 9; From f0aad3feab78f77dd94c4fcda798eea7b8b4cd08 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 29 Oct 2024 23:37:28 +0000 Subject: [PATCH 23/96] add method_definitions_set to function proto --- modal/functions.py | 1 + modal_proto/api.proto | 1 + 2 files changed, 2 insertions(+) diff --git a/modal/functions.py b/modal/functions.py index 4bbaf763a..02d5b1ac8 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -832,6 +832,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona function_type=function_type, webhook_config=webhook_config, method_definitions=method_definitions, + method_definitions_set=True, shared_volume_mounts=network_file_system_mount_protos( validated_network_file_systems, allow_cross_region_volumes ), diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 862dae931..73487d1b6 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -1202,6 +1202,7 @@ message Function { // Need a mapping of method names to method definitions map method_definitions = 74; + bool method_definitions_set = 75; } message FunctionBindParamsRequest { From 735792fafca5b6a05fed2fde71c8b6c801052d67 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 29 Oct 2024 23:58:26 +0000 Subject: [PATCH 24/96] Include only_class_function in requests --- modal/_container_io_manager.py | 2 +- modal/cls.py | 5 ++++- modal/runner.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 92f099163..c489ee374 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -467,7 +467,7 @@ async def _dynamic_concurrency_loop(self): await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS) async def get_app_objects(self) -> RunningApp: - req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True) + req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True, only_class_function=True) resp = await retry_transient_errors(self._client.stub.AppGetObjects, req) logger.debug(f"AppGetObjects received {len(resp.items)} objects for app {self.app_id}") diff --git a/modal/cls.py b/modal/cls.py index d57fc39b9..cb227859b 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -345,7 +345,9 @@ def _deps() -> List[_Function]: return [class_service_function] async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[str]): - req = api_pb2.ClassCreateRequest(app_id=resolver.app_id, existing_class_id=existing_object_id) + req = api_pb2.ClassCreateRequest( + app_id=resolver.app_id, existing_class_id=existing_object_id, only_class_function=True + ) # for f_name, f in self._method_functions.items(): # req.methods.append( # api_pb2.ClassMethod( @@ -405,6 +407,7 @@ async def _load_remote(obj: _Object, resolver: Resolver, existing_object_id: Opt environment_name=_environment_name, lookup_published=workspace is not None, workspace_name=workspace, + only_class_function=True, ) try: response = await retry_transient_errors(resolver.client.stub.ClassGet, request) diff --git a/modal/runner.py b/modal/runner.py index b30bca8b7..79f32afd4 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -58,7 +58,7 @@ async def _heartbeat(client: _Client, app_id: str) -> None: async def _init_local_app_existing(client: _Client, existing_app_id: str, environment_name: str) -> RunningApp: # Get all the objects first - obj_req = api_pb2.AppGetObjectsRequest(app_id=existing_app_id) + obj_req = api_pb2.AppGetObjectsRequest(app_id=existing_app_id, only_class_function=True) obj_resp, _ = await asyncio.gather( retry_transient_errors(client.stub.AppGetObjects, obj_req), # Cache the environment associated with the app now as we will use it later From 4f0b2b1d4872b3a340f5718e331be39db4b2161f Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 30 Oct 2024 23:38:13 +0000 Subject: [PATCH 25/96] support lookup from new clients --- modal/cls.py | 28 ++++++++++++++++------------ modal/functions.py | 3 +++ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index cb227859b..e36599b96 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -266,18 +266,22 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) - - # for method in metadata.methods: - # if method.function_name in self._method_functions: - # # This happens when the class is loaded locally - # # since each function will already be a loaded dependency _Function - # self._method_functions[method.function_name]._hydrate( - # method.function_id, self._client, method.function_handle_metadata - # ) - # else: - # self._method_functions[method.function_name] = _Function._new_hydrated( - # method.function_id, self._client, method.function_handle_metadata - # ) + if self._class_service_function: + if self._class_service_function._method_functions: + # The class only has a class service service function and no method placeholders. + return + else: + for method in metadata.methods: + if method.function_name in self._class_service_function._method_functions: + # This happens when the class is loaded locally + # since each function will already be a loaded dependency _Function + self._class_service_function._method_functions[method.function_name]._hydrate( + method.function_id, self._client, method.function_handle_metadata + ) + else: + self._class_service_function._method_functions[method.function_name] = _Function._new_hydrated( + method.function_id, self._client, method.function_handle_metadata + ) def _get_metadata(self) -> api_pb2.ClassHandleMetadata: class_handle_metadata = api_pb2.ClassHandleMetadata() diff --git a/modal/functions.py b/modal/functions.py index 02d5b1ac8..1f252f0dd 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -315,6 +315,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _build_args: dict _can_use_base_function: bool = False # whether we need to call FunctionBindParams _is_generator: Optional[bool] = None + _definition_id: str # when this is the method of a class/object function, invocation of this function # should be using another function id and supply the method name in the FunctionInput: @@ -1151,6 +1152,7 @@ async def _load( method_function._all_mounts = self._all_mounts method_function._spec = self._spec method_function._is_method = True + method_function._definition_id = self._definition_id method_function._hydrate(response.function_id, resolver.client, method_handle_metadata) self._method_functions[method_name] = method_function @@ -1230,6 +1232,7 @@ def _initialize_from_empty(self): # self._use_function_id = "" self._use_method_name = "" self._method_functions = {} + self._definition_id = "" def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method From 6fcb498f5e8ba5317a535d4eb590cb141b52d550 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 31 Oct 2024 16:05:11 +0000 Subject: [PATCH 26/96] Support new client looking up 0.62 class --- modal/cls.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/modal/cls.py b/modal/cls.py index e36599b96..b041017f5 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -267,6 +267,7 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) if self._class_service_function: + print("foo") if self._class_service_function._method_functions: # The class only has a class service service function and no method placeholders. return @@ -282,6 +283,17 @@ def _hydrate_metadata(self, metadata: Message): self._class_service_function._method_functions[method.function_name] = _Function._new_hydrated( method.function_id, self._client, method.function_handle_metadata ) + else: + # We are dealing with a pre 0.63 class that does not have a class service function and only method functions + async def _load(): + pass + + rep = "Function(class_function)" + self._class_service_function = _Function._from_loader(_load, rep) + for method in metadata.methods: + self._class_service_function._method_functions[method.function_name] = _Function._new_hydrated( + method.function_id, self._client, method.function_handle_metadata + ) def _get_metadata(self) -> api_pb2.ClassHandleMetadata: class_handle_metadata = api_pb2.ClassHandleMetadata() From 7215fba778e24424546d5e3bc08995748854e76e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 31 Oct 2024 21:09:19 +0000 Subject: [PATCH 27/96] obj init cleanup --- modal/cls.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index b041017f5..a5d559658 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -91,7 +91,7 @@ def __init__( self, user_cls: type, class_service_function: Optional[_Function], # only None for = v0.63 classes # first create the singular object function used by all methods on this parameterization self._instance_service_function = class_service_function._bind_parameters( self, from_other_workspace, options, args, kwargs ) - for method_name, class_bound_method in classbound_methods.items(): + for method_name, class_bound_method in class_service_function._method_functions.items(): method = self._instance_service_function._bind_instance_method(class_bound_method) self._method_functions[method_name] = method else: # Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) if self._class_service_function: - print("foo") if self._class_service_function._method_functions: # The class only has a class service service function and no method placeholders. return @@ -542,7 +541,7 @@ def __call__(self, *args, **kwargs) -> _Obj: return _Obj( self._user_cls, self._class_service_function, - self._class_service_function._method_functions, + # self._class_service_function._method_functions, self._from_other_workspace, self._options, args, From 1b021924f3648838e9fed3fe2210e55392e69cc6 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 4 Nov 2024 15:47:20 +0000 Subject: [PATCH 28/96] Remove duplicate MessageDefinition --- modal_proto/api.proto | 8 -------- 1 file changed, 8 deletions(-) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 1bb1011fc..c173e903f 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -2536,14 +2536,6 @@ message VolumeRemoveFileRequest { bool recursive = 3; } -message MethodDefinition { - string function_name = 1; - Function.FunctionType function_type = 2; - WebhookConfig webhook_config = 3; - string web_url = 4; - WebUrlInfo web_url_info = 5; -} - message WebUrlInfo { bool truncated = 1; bool has_unique_hash = 2 [deprecated=true]; From c63621fb346ab2ecbc0c27ffdc5d194cf226aaaf Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 6 Nov 2024 20:33:19 +0000 Subject: [PATCH 29/96] Make tag parameter non-optional in from_name and lookup for Function and Cls --- modal/cls.py | 4 ++-- modal/functions.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index 06078c5a2..7d1bb043d 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -383,7 +383,7 @@ def _uses_common_service_function(self): def from_name( cls: Type["_Cls"], app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, environment_name: Optional[str] = None, workspace: Optional[str] = None, @@ -498,7 +498,7 @@ def with_options( @staticmethod async def lookup( app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, client: Optional[_Client] = None, environment_name: Optional[str] = None, diff --git a/modal/functions.py b/modal/functions.py index 37100c314..06878c4f7 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1057,7 +1057,7 @@ async def keep_warm(self, warm_pool_size: int) -> None: def from_name( cls: Type["_Function"], app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, environment_name: Optional[str] = None, ) -> "_Function": @@ -1072,7 +1072,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: assert resolver.client and resolver.client.stub request = api_pb2.FunctionGetRequest( app_name=app_name, - object_tag=tag or "", + object_tag=tag, namespace=namespace, environment_name=_get_environment_name(environment_name, resolver) or "", ) @@ -1092,7 +1092,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: @staticmethod async def lookup( app_name: str, - tag: Optional[str] = None, + tag: str, namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, client: Optional[_Client] = None, environment_name: Optional[str] = None, From 289d05b4c7df2e02e77f114c4dbf127c47de75a8 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 13 Nov 2024 22:33:35 +0000 Subject: [PATCH 30/96] Add util to get api_pb2.Function.FunctionType from is_generator --- modal/_utils/function_utils.py | 4 ++++ modal/functions.py | 17 ++++------------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 7f718a9e6..d416037e6 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -93,6 +93,10 @@ def is_async(function): raise RuntimeError(f"Function {function} is a strange type {type(function)}") +def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType: + return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION + + class FunctionInfo: """Class that helps us extract a bunch of information about a function.""" diff --git a/modal/functions.py b/modal/functions.py index 08dd09197..f23dfe894 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -52,6 +52,7 @@ _create_input, _process_result, _stream_function_call_data, + get_function_type, is_async, ) from ._utils.grpc_utils import retry_transient_errors @@ -350,11 +351,7 @@ def _bind_method( assert not class_service_function._is_method # should not be used on an already bound method placeholder assert not class_service_function._obj # should only be used on base function / class service function full_name = f"{user_cls.__name__}.{method_name}" - - if partial_function.is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = get_function_type(partial_function.is_generator) async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): function_definition = api_pb2.Function( @@ -714,10 +711,7 @@ def _deps(only_explicit_mounts=False) -> List[_Object]: async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub - if is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = get_function_type(is_generator) assert resolver.app_id req = api_pb2.FunctionPrecreateRequest( @@ -733,10 +727,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub with FunctionCreationStatus(resolver, tag) as function_creation_status: - if is_generator: - function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR - else: - function_type = api_pb2.Function.FUNCTION_TYPE_FUNCTION + function_type = get_function_type(is_generator) timeout_secs = timeout From 1fd397a460bbbf3f7fe308f6d15f42ecf874655e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 13 Nov 2024 22:58:48 +0000 Subject: [PATCH 31/96] fix type check --- modal/_utils/function_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index d416037e6..a68ccb966 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -93,7 +93,7 @@ def is_async(function): raise RuntimeError(f"Function {function} is a strange type {type(function)}") -def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType: +def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType.ValueType: return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION From 2cf1dfbf5f93f1bc5c9f9be1c5631dab0438ee21 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 13 Nov 2024 22:59:53 +0000 Subject: [PATCH 32/96] fix type check --- modal/_utils/function_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index d416037e6..a68ccb966 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -93,7 +93,7 @@ def is_async(function): raise RuntimeError(f"Function {function} is a strange type {type(function)}") -def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType: +def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType.ValueType: return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION From 6820379fa7a3d2d2aa53c6d891edf2839062f6c6 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 13 Nov 2024 23:04:54 +0000 Subject: [PATCH 33/96] is_generator can be None --- modal/_utils/function_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index a68ccb966..b2995c9ae 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -93,7 +93,7 @@ def is_async(function): raise RuntimeError(f"Function {function} is a strange type {type(function)}") -def get_function_type(is_generator: bool) -> api_pb2.Function.FunctionType.ValueType: +def get_function_type(is_generator: Optional[bool]) -> api_pb2.Function.FunctionType.ValueType: return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION From 8603419063855247e101213a180bfb75989c97b9 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 13 Nov 2024 23:10:28 +0000 Subject: [PATCH 34/96] add quotes --- modal/_utils/function_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index b2995c9ae..c2d84569b 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -93,7 +93,7 @@ def is_async(function): raise RuntimeError(f"Function {function} is a strange type {type(function)}") -def get_function_type(is_generator: Optional[bool]) -> api_pb2.Function.FunctionType.ValueType: +def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.FunctionType.ValueType": return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION From 1571203b8379e2667d3a90c50a4acc5f9892a9b0 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 15:43:23 +0000 Subject: [PATCH 35/96] cleanup --- modal/functions.py | 58 +--------------------------------------------- 1 file changed, 1 insertion(+), 57 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 6d5b11233..8d20016da 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -321,8 +321,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _cluster_size: Optional[int] = None # when this is the method of a class/object function, invocation of this function - # should be using another function id and supply the method name in the FunctionInput: - # _use_function_id: str # The function to invoke + # should supply the method name in the FunctionInput: _use_method_name: str = "" # TODO (elias): remove _parent. In case of instance functions, and methods bound on those, @@ -354,63 +353,11 @@ def _bind_method( assert not class_service_function._obj # should only be used on base function / class service function full_name = f"{user_cls.__name__}.{method_name}" - # async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): - # from ._output import FunctionCreationStatus # Deferred import to avoid Rich dependency in container - - # function_definition = api_pb2.Function( - # function_name=full_name, - # webhook_config=partial_function.webhook_config, - # function_type=function_type, - # is_method=True, - # use_function_id=class_service_function.object_id, - # use_method_name=method_name, - # batch_max_size=partial_function.batch_max_size or 0, - # batch_linger_ms=partial_function.batch_wait_ms or 0, - # ) - # assert resolver.app_id - # request = api_pb2.FunctionCreateRequest( - # app_id=resolver.app_id, - # function=function_definition, - # # method_bound_function.object_id usually gets set by preload - # existing_function_id=existing_object_id or method_bound_function.object_id or "", - # defer_updates=True, - # ) - # assert resolver.client.stub is not None # client should be connected when load is called - # with FunctionCreationStatus(resolver, full_name) as function_creation_status: - # response = await resolver.client.stub.FunctionCreate(request) - # method_bound_function._hydrate( - # response.function_id, - # resolver.client, - # response.handle_metadata, - # ) - # function_creation_status.set_response(response) - - # async def _preload(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): - # if class_service_function._use_method_name: - # raise ExecutionError(f"Can't bind method to already bound {class_service_function}") - # assert resolver.app_id - # req = api_pb2.FunctionPrecreateRequest( - # app_id=resolver.app_id, - # function_name=full_name, - # function_type=function_type, - # webhook_config=partial_function.webhook_config, - # use_function_id=class_service_function.object_id, - # use_method_name=method_name, - # existing_function_id=existing_object_id or "", - # ) - # assert resolver.client.stub # client should be connected at this point - # response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) - # method_bound_function._hydrate(response.function_id, resolver.client, response.handle_metadata) - - # def _deps(): - # return [class_service_function] - async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): pass rep = f"Method({full_name})" fun = _Function._from_loader(_load, rep) - # fun = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) fun._tag = full_name fun._raw_f = partial_function.raw_f fun._info = FunctionInfo( @@ -600,9 +547,6 @@ def from_args( ) if info.user_cls and not is_auto_snapshot: - # # Needed to avoid circular imports - # from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags - build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items() for k, pf in build_functions: build_function = pf.raw_f From 32ee93a24809fbb9886bd380edf760806467c0ad Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 16:18:23 +0000 Subject: [PATCH 36/96] Send method definitions in FunctionPrecreateRequest --- modal/functions.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 4e58f9e7e..3e46126ee 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -523,6 +523,9 @@ def from_args( _experimental_custom_scaling_factor: Optional[float] = None, ) -> None: """mdmd:hidden""" + # Needed to avoid circular imports + from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags + tag = info.get_tag() if info.raw_f: @@ -591,9 +594,6 @@ def from_args( ) if info.user_cls and not is_auto_snapshot: - # Needed to avoid circular imports - from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags - build_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.BUILD).items() for k, pf in build_functions: build_function = pf.raw_f @@ -682,6 +682,22 @@ def from_args( if image is not None and not isinstance(image, _Image): raise InvalidError(f"Expected modal.Image object. Got {type(image)}.") + method_definitions: Optional[Dict[str, api_pb2.MethodDefinition]] = None + if info.user_cls: + method_definitions = {} + partial_functions: Dict[ + str, "modal.partial_function._PartialFunction" + ] = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION) + for method_name, partial_function in partial_functions.items(): + function_type = get_function_type(partial_function.is_generator) + function_name = f"{info.user_cls.__name__}.{method_name}" + method_definition = api_pb2.MethodDefinition( + webhook_config=partial_function.webhook_config, + function_type=function_type, + function_name=function_name, + ) + method_definitions[method_name] = method_definition + def _deps(only_explicit_mounts=False) -> List[_Object]: deps: List[_Object] = list(secrets) if only_explicit_mounts: @@ -718,9 +734,13 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti app_id=resolver.app_id, function_name=info.function_name, function_type=function_type, - webhook_config=webhook_config, existing_function_id=existing_object_id or "", ) + if method_definitions: + for method_name, method_definition in method_definitions.items(): + req.method_definitions[method_name].CopyFrom(method_definition) + else: + req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) From 5197215433d3f54d1a6381748a07ea6f8e067d2b Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 16:29:49 +0000 Subject: [PATCH 37/96] Can't pass null webhook_config to CopyFrom --- modal/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index 3e46126ee..4b94712d4 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -739,7 +739,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti if method_definitions: for method_name, method_definition in method_definitions.items(): req.method_definitions[method_name].CopyFrom(method_definition) - else: + elif req.webhook_config: req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) From bcdf8ebb07fde3ad3e7a7ac927f04faacebed9b3 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 16:32:13 +0000 Subject: [PATCH 38/96] fix typo --- modal/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index 4b94712d4..712c68ec4 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -739,7 +739,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti if method_definitions: for method_name, method_definition in method_definitions.items(): req.method_definitions[method_name].CopyFrom(method_definition) - elif req.webhook_config: + elif webhook_config: req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) self._hydrate(response.function_id, resolver.client, response.handle_metadata) From e3540993944ebc16a61fbbb04994304cf5fdd395 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 19:51:00 +0000 Subject: [PATCH 39/96] Add _method_functions attribute to _Function --- modal/cls.py | 2 +- modal/functions.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/modal/cls.py b/modal/cls.py index 1ca19bf3b..8b37a27fc 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -326,7 +326,7 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio ) for method_name, partial_function in partial_functions.items(): - method_function = class_service_function._bind_method(user_cls, method_name, partial_function) + method_function = class_service_function._bind_method_old(user_cls, method_name, partial_function) app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) partial_function.wrapped = True functions[method_name] = method_function diff --git a/modal/functions.py b/modal/functions.py index 712c68ec4..6f2c3665e 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -329,6 +329,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _parent: Optional["_Function"] = None _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None + _method_functions: Optional[Dict[str, "_Function"]] = None # Placeholder _Functions for each method def _bind_method( self, @@ -345,6 +346,44 @@ def _bind_method( we don't create an actual backend function, and instead do client-side "fake-hydration" only, see _bind_instance_method. + """ + class_service_function = self + assert class_service_function._info # has to be a local function to be able to "bind" it + assert not class_service_function._is_method # should not be used on an already bound method placeholder + assert not class_service_function._obj # should only be used on base function / class service function + full_name = f"{user_cls.__name__}.{method_name}" + + rep = f"Method({full_name})" + fun = _Object.__new__(_Function) + fun._init(rep) + fun._tag = full_name + fun._raw_f = partial_function.raw_f + fun._info = FunctionInfo( + partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized() + ) # needed for .local() + fun._use_method_name = method_name + fun._app = class_service_function._app + fun._is_generator = partial_function.is_generator + fun._cluster_size = partial_function.cluster_size + fun._spec = class_service_function._spec + fun._is_method = True + return fun + + def _bind_method_old( + self, + user_cls, + method_name: str, + partial_function: "modal.partial_function._PartialFunction", + ): + """mdmd:hidden + + Creates a function placeholder function that binds a specific method name to + this function for use when invoking the function. + + Should only be used on "class service functions". For "instance service functions", + we don't create an actual backend function, and instead do client-side "fake-hydration" + only, see _bind_instance_method. + """ class_service_function = self assert class_service_function._info # has to be a local function to be able to "bind" it @@ -950,6 +989,12 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._is_method = False obj._spec = function_spec # needed for modal shell + if info.user_cls: + obj._method_functions = {} + for method_name, partial_function in partial_functions.items(): + method_function = obj._bind_method(info.user_cls, method_name, partial_function) + obj._method_functions[method_name] = method_function + # Used to check whether we should rebuild a modal.Image which uses `run_function`. gpus: List[GPU_T] = gpu if isinstance(gpu, list) else [gpu] obj._build_args = dict( # See get_build_def From 5cab92af6f4e9620487e31388c8d682232818156 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 20:05:15 +0000 Subject: [PATCH 40/96] fix unbound error --- modal/functions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 6f2c3665e..4e0ae6fe5 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -722,11 +722,10 @@ def from_args( raise InvalidError(f"Expected modal.Image object. Got {type(image)}.") method_definitions: Optional[Dict[str, api_pb2.MethodDefinition]] = None + partial_functions: Dict[str, "modal.partial_function._PartialFunction"] = {} if info.user_cls: method_definitions = {} - partial_functions: Dict[ - str, "modal.partial_function._PartialFunction" - ] = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION) + partial_functions = _find_partial_methods_for_user_cls(info.user_cls, _PartialFunctionFlags.FUNCTION) for method_name, partial_function in partial_functions.items(): function_type = get_function_type(partial_function.is_generator) function_name = f"{info.user_cls.__name__}.{method_name}" From 2e1cce67f2ce1a708bf730b49f4fe346219e5b1e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 20:24:49 +0000 Subject: [PATCH 41/96] update comment --- modal/functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 4e0ae6fe5..462e60e4a 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -339,12 +339,8 @@ def _bind_method( ): """mdmd:hidden - Creates a function placeholder function that binds a specific method name to - this function for use when invoking the function. - - Should only be used on "class service functions". For "instance service functions", - we don't create an actual backend function, and instead do client-side "fake-hydration" - only, see _bind_instance_method. + Creates a _Function that is bound to a specific class method name. This _Function is not uniquely tied + to any backend function -- its object_id is the function ID of the class service function. """ class_service_function = self From d16ccc726ddd50a146a90a46fb4524fbcfe87d27 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 20:43:59 +0000 Subject: [PATCH 42/96] remove bind_method_old --- modal/cls.py | 2 +- modal/functions.py | 40 ---------------------------------------- 2 files changed, 1 insertion(+), 41 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index 0fd7319aa..ff241cc2d 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -341,7 +341,7 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio ) for partial_function in partial_functions.values(): - # method_function = class_service_function._bind_method(user_cls, method_name, partial_function) + # method_function = class_service_function._bind_method_old(user_cls, method_name, partial_function) # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) partial_function.wrapped = True # functions[method_name] = method_function diff --git a/modal/functions.py b/modal/functions.py index 0d277e89b..bc98867a3 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -365,46 +365,6 @@ def _bind_method( fun._is_method = True return fun - def _bind_method_old( - self, - user_cls, - method_name: str, - partial_function: "modal.partial_function._PartialFunction", - ): - """mdmd:hidden - - Creates a function placeholder function that binds a specific method name to - this function for use when invoking the function. - - Should only be used on "class service functions". For "instance service functions", - we don't create an actual backend function, and instead do client-side "fake-hydration" - only, see _bind_instance_method. - - """ - class_service_function = self - assert class_service_function._info # has to be a local function to be able to "bind" it - assert not class_service_function._is_method # should not be used on an already bound method placeholder - assert not class_service_function._obj # should only be used on base function / class service function - full_name = f"{user_cls.__name__}.{method_name}" - - async def _load(method_bound_function: "_Function", resolver: Resolver, existing_object_id: Optional[str]): - pass - - rep = f"Method({full_name})" - fun = _Function._from_loader(_load, rep) - fun._tag = full_name - fun._raw_f = partial_function.raw_f - fun._info = FunctionInfo( - partial_function.raw_f, user_cls=user_cls, serialized=class_service_function.info.is_serialized() - ) # needed for .local() - fun._use_method_name = method_name - fun._app = class_service_function._app - fun._is_generator = partial_function.is_generator - fun._cluster_size = partial_function.cluster_size - fun._spec = class_service_function._spec - fun._is_method = True - return fun - def _bind_instance_method(self, class_bound_method: "_Function"): """mdmd:hidden From ec894994907ce76c998524ab9cab9aa87b9e1da9 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 21:16:02 +0000 Subject: [PATCH 43/96] cleanup --- modal/functions.py | 18 +----------------- modal/parallel_map.py | 4 ++-- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index bc98867a3..82b4cbaf1 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -118,7 +118,7 @@ async def create( function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType", ) -> "_Invocation": assert client.stub - function_id = function._invocation_function_id() + function_id = function.object_id item = await _create_input(args, kwargs, client, method_name=function._use_method_name) request = api_pb2.FunctionMapRequest( @@ -897,7 +897,6 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._raw_f = info.raw_f obj._info = info - obj._function_name = info.function_name obj._tag = tag obj._app = app # needed for CLI right now obj._obj = None @@ -1184,24 +1183,9 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._web_url = metadata.web_url self._function_name = metadata.function_name self._is_method = metadata.is_method - # self._use_function_id = metadata.use_function_id self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info self._definition_id = metadata.definition_id - # for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): - # method_function = self._method_functions[method_name] - # method_function._is_generator = ( - # method_handle_metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR - # ) - # method_function._web_url = method_handle_metadata.web_url - # method_function._function_name = method_handle_metadata.function_name - # method_function._is_method = method_handle_metadata.is_method - # method_function._use_method_name = method_handle_metadata.use_method_name - # method_function._definition_id = method_handle_metadata.definition_id - - def _invocation_function_id(self) -> str: - # return self._use_function_id or self.object_id - return self.object_id def _get_metadata(self): # Overridden concrete implementation of base class method diff --git a/modal/parallel_map.py b/modal/parallel_map.py index d79ea630d..b8a14dba8 100644 --- a/modal/parallel_map.py +++ b/modal/parallel_map.py @@ -78,7 +78,7 @@ async def _map_invocation( ): assert client.stub request = api_pb2.FunctionMapRequest( - function_id=function._invocation_function_id(), + function_id=function.object_id, parent_input_id=current_input_id() or "", function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP, return_exceptions=return_exceptions, @@ -131,7 +131,7 @@ async def pump_inputs(): nonlocal have_all_inputs, num_inputs async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE): request = api_pb2.FunctionPutInputsRequest( - function_id=function._invocation_function_id(), inputs=items, function_call_id=function_call_id + function_id=function.object_id, inputs=items, function_call_id=function_call_id ) logger.debug( f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}." From 9fe8165423dbc84c6da6a7bd452fb3ec124d0e9b Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 14 Nov 2024 22:43:16 +0000 Subject: [PATCH 44/96] Hydrate _method_functions on _Function --- modal/functions.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 7c45a455b..bb460b438 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -517,6 +517,26 @@ def _deps(): fun._spec = class_bound_method._spec return fun + def _hydrate_function_and_method_functions( + self, function_id: str, client: _Client, handle_metadata: api_pb2.FunctionHandleMetadata + ): + self._hydrate(function_id, client, handle_metadata) + if self._method_functions: + # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a + # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects + for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): + if method_name in self._method_functions: + method_function = self._method_functions[method_name] + method_function._hydrate(function_id, client, method_handle_metadata) + elif len(handle_metadata.method_handle_metadata): + # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based + # on the existence of method_handle_metadata that this is a class service function + self._method_functions = {} + for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): + self._method_functions[method_name] = _Function._new_hydrated( + function_id, client, method_handle_metadata + ) + @staticmethod def from_args( info: FunctionInfo, @@ -776,7 +796,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti elif webhook_config: req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) - self._hydrate(response.function_id, resolver.client, response.handle_metadata) + self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub @@ -974,7 +994,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona local_mounts = set(m for m in all_mounts if m.is_local()) # needed for modal.serve file watching local_mounts |= image._used_local_mounts obj._used_local_mounts = frozenset(local_mounts) - self._hydrate(response.function_id, resolver.client, response.handle_metadata) + self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) @@ -1151,7 +1171,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: else: raise - self._hydrate(response.function_id, resolver.client, response.handle_metadata) + self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) rep = f"Ref({app_name})" return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True) From 5b8ae616e047144173259b6650e3d5763427ca1e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 16:38:47 +0000 Subject: [PATCH 45/96] Update mock FunctionPrecreate --- test/conftest.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 976ed2c10..ae3eb13b3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -901,6 +901,16 @@ async def FunctionPrecreate(self, stream): self.precreated_functions.add(function_id) web_url = "http://xyz.internal" if req.HasField("webhook_config") and req.webhook_config.type else None + + # This loop is for class service functions, where req.method_definitions will be non-empty + method_handle_metadata: dict[str, api_pb2.FunctionHandleMetadata] = {} + for method_name, method_definition in req.method_definitions.items(): + method_web_url = f"https://{method_name}.internal" + method_handle_metadata[method_name] = api_pb2.FunctionHandleMetadata( + function_name=method_definition.function_name, + function_type=method_definition.function_type, + web_url=method_web_url, + ) await stream.send_message( api_pb2.FunctionPrecreateResponse( function_id=function_id, @@ -910,6 +920,7 @@ async def FunctionPrecreate(self, stream): web_url=web_url, use_function_id=req.use_function_id or function_id, use_method_name=req.use_method_name, + method_handle_metadata=method_handle_metadata, ), ) ) From 239a23dd51438bb59ac6f6662c81e50714204a08 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 17:38:55 +0000 Subject: [PATCH 46/96] Populate FunctionHandleMetadata.method_handle_metadata in _Function _get_metadata --- modal/functions.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modal/functions.py b/modal/functions.py index 79529b9a3..651648a31 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1278,6 +1278,17 @@ def _get_metadata(self): is_method=self._is_method, class_parameter_info=self._class_parameter_info, definition_id=self._definition_id, + method_handle_metadata={ + method_name: api_pb2.FunctionHandleMetadata( + function_name=method_function._function_name, + function_type=get_function_type(method_function._is_generator), + web_url=method_function._web_url or "", + is_method=method_function._is_method, + definition_id=method_function._definition_id, + use_method_name=method_function._use_method_name, + ) + for method_name, method_function in self._method_functions.items() + }, ) def _check_no_web_url(self, fn_name: str): From a11f647b9be05e57a17ace74ba67a9ef860fb6fd Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 17:45:12 +0000 Subject: [PATCH 47/96] fix type checks --- modal/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index 651648a31..c94dbef10 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1288,7 +1288,10 @@ def _get_metadata(self): use_method_name=method_function._use_method_name, ) for method_name, method_function in self._method_functions.items() - }, + if method_function._function_name + } + if self._method_functions + else None, ) def _check_no_web_url(self, fn_name: str): From 1157219e0d22c2e4b6bc20f62c4f7f7ed4c71062 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 19:18:17 +0000 Subject: [PATCH 48/96] Remove extraneous changes --- modal/functions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 083c1eecc..90cca0de4 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -317,7 +317,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _build_args: dict _can_use_base_function: bool = False # whether we need to call FunctionBindParams _is_generator: Optional[bool] = None - _definition_id: str _cluster_size: Optional[int] = None # when this is the method of a class/object function, invocation of this function @@ -1162,8 +1161,6 @@ def _initialize_from_empty(self): self._web_url = None self._function_name = None self._info = None - self._use_method_name = "" - self._definition_id = "" self._used_local_mounts = frozenset() def _hydrate_metadata(self, metadata: Optional[Message]): From 611ed93daca7209bb0c79f91399ba0246ac5df35 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 19:29:06 +0000 Subject: [PATCH 49/96] cleanup --- modal/cls.py | 35 +---------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index ff241cc2d..d63294ee4 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -90,8 +90,7 @@ def _uses_common_service_function(self): def __init__( self, user_cls: type, - class_service_function: Optional[_Function], # only None for =v0.63 - # _method_functions: Dict[str, _Function] # Placeholder _Functions for each method _options: Optional[api_pb2.FunctionOptions] _callables: Dict[str, Callable[..., Any]] _from_other_workspace: Optional[bool] # Functions require FunctionBindParams before invocation. @@ -246,7 +244,6 @@ class _Cls(_Object, type_prefix="cs"): def _initialize_from_empty(self): self._user_cls = None self._class_service_function = None - # self._method_functions = {} self._options = None self._callables = {} self._from_other_workspace = None @@ -254,7 +251,6 @@ def _initialize_from_empty(self): def _initialize_from_other(self, other: "_Cls"): self._user_cls = other._user_cls self._class_service_function = other._class_service_function - # self._method_functions = other._method_functions self._options = other._options self._callables = other._callables self._from_other_workspace = other._from_other_workspace @@ -296,12 +292,6 @@ async def _load(): def _get_metadata(self) -> api_pb2.ClassHandleMetadata: class_handle_metadata = api_pb2.ClassHandleMetadata() - # for f_name, f in self._method_functions.items(): - # class_handle_metadata.methods.append( - # api_pb2.ClassMethod( - # function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata() - # ) - # ) return class_handle_metadata @staticmethod @@ -335,16 +325,12 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio # validate signature _Cls.validate_construction_mechanism(user_cls) - # functions: Dict[str, _Function] = {} partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( user_cls, _PartialFunctionFlags.FUNCTION ) for partial_function in partial_functions.values(): - # method_function = class_service_function._bind_method_old(user_cls, method_name, partial_function) - # app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) partial_function.wrapped = True - # functions[method_name] = method_function # Disable the warning that these are not wrapped for partial_function in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags.FUNCTION).values(): @@ -356,30 +342,13 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio } def _deps() -> List[_Function]: - # return [class_service_function] + list(functions.values()) return [class_service_function] async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.ClassCreateRequest( app_id=resolver.app_id, existing_class_id=existing_object_id, only_class_function=True ) - # for f_name, f in self._method_functions.items(): - # req.methods.append( - # api_pb2.ClassMethod( - # function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata() - # ) - # ) resp = await resolver.client.stub.ClassCreate(req) - # Even though we already have the function_handle_metadata for this method locally, - # The RPC is going to replace it with function_handle_metadata derived from the server. - # We need to overwrite the definition_id sent back from the server here with the definition_id - # previously stored in function metadata, which may have been sent back from FunctionCreate. - # The problem is that this metadata propagates back and overwrites the metadata on the Function - # object itself. This is really messy. Maybe better to exclusively populate the method metadata - # from the function metadata we already have locally? Really a lot to clean up here... - # for method in resp.handle_metadata.methods: - # f_metadata = self._method_functions[method.function_name]._get_metadata() - # method.function_handle_metadata.definition_id = f_metadata.definition_id self._hydrate(resp.class_id, resolver.client, resp.handle_metadata) rep = f"Cls({user_cls.__name__})" @@ -387,7 +356,6 @@ async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[s cls._app = app cls._user_cls = user_cls cls._class_service_function = class_service_function - # cls._method_functions = functions cls._callables = callables cls._from_other_workspace = False return cls @@ -548,7 +516,6 @@ def __call__(self, *args, **kwargs) -> _Obj: return _Obj( self._user_cls, self._class_service_function, - # self._class_service_function._method_functions, self._from_other_workspace, self._options, args, From 4502e2b16ddcb86f50cef90453b7f453a17fcc71 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 23:09:54 +0000 Subject: [PATCH 50/96] Add _is_web_endpoint to _Function --- modal/app.py | 12 ++++++++---- modal/cls.py | 2 +- modal/functions.py | 8 ++++++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/modal/app.py b/modal/app.py index 125540510..da6576a56 100644 --- a/modal/app.py +++ b/modal/app.py @@ -492,7 +492,7 @@ def _get_watch_mounts(self): return [m for m in all_mounts if m.is_local()] - def _add_function(self, function: _Function, is_web_endpoint: bool): + def _add_function(self, function: _Function): if function.tag in self._indexed_objects: old_function = self._indexed_objects[function.tag] if isinstance(old_function, _Function): @@ -507,8 +507,12 @@ def _add_function(self, function: _Function, is_web_endpoint: bool): logger.warning(f"Warning: tag {function.tag} exists but is overridden by function") self._add_object(function.tag, function) - if is_web_endpoint: + if function._is_web_endpoint: self._web_endpoints.append(function.tag) + if function._method_functions: + for method_function in function._method_functions.values(): + if method_function._is_web_endpoint: + self._web_endpoints.append(method_function.tag) def _init_container(self, client: _Client, running_app: RunningApp): self._app_id = running_app.app_id @@ -815,7 +819,7 @@ def f(self, x): cluster_size=cluster_size, # Experimental: Clustered functions ) - self._add_function(function, webhook_config is not None) + self._add_function(function) return function @@ -944,7 +948,7 @@ def wrapper(user_cls: CLS_T) -> CLS_T: _experimental_custom_scaling_factor=_experimental_custom_scaling_factor, ) - self._add_function(cls_func, is_web_endpoint=False) + self._add_function(cls_func) cls: _Cls = _Cls.from_local(user_cls, self, cls_func) diff --git a/modal/cls.py b/modal/cls.py index 8b37a27fc..49cdbc65e 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -327,7 +327,7 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio for method_name, partial_function in partial_functions.items(): method_function = class_service_function._bind_method_old(user_cls, method_name, partial_function) - app._add_function(method_function, is_web_endpoint=partial_function.webhook_config is not None) + app._add_function(method_function) partial_function.wrapped = True functions[method_name] = method_function diff --git a/modal/functions.py b/modal/functions.py index c94dbef10..5d06708fa 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -317,6 +317,9 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _build_args: dict _can_use_base_function: bool = False # whether we need to call FunctionBindParams _is_generator: Optional[bool] = None + _is_web_endpoint: Optional[ + bool + ] = None # used to determine whether an un-hydrated function is a web endpoint, since web_url won't be populated _cluster_size: Optional[int] = None # when this is the method of a class/object function, invocation of this function @@ -360,6 +363,7 @@ def _bind_method( fun._use_method_name = method_name fun._app = class_service_function._app fun._is_generator = partial_function.is_generator + fun._is_web_endpoint = partial_function.webhook_config is not None fun._cluster_size = partial_function.cluster_size fun._spec = class_service_function._spec fun._is_method = True @@ -448,6 +452,7 @@ def _deps(): fun._use_method_name = method_name fun._app = class_service_function._app fun._is_generator = partial_function.is_generator + fun._is_web_endpoint = partial_function.webhook_config is not None fun._cluster_size = partial_function.cluster_size fun._spec = class_service_function._spec fun._is_method = True @@ -478,6 +483,7 @@ def hydrate_from_instance_service_function(method_placeholder_fun): ) # TODO: this shouldn't be set when actual parameters are used method_placeholder_fun._function_name = full_function_name method_placeholder_fun._is_generator = class_bound_method._is_generator + method_placeholder_fun._is_web_endpoint = class_bound_method._is_web_endpoint method_placeholder_fun._cluster_size = class_bound_method._cluster_size method_placeholder_fun._use_method_name = method_name method_placeholder_fun._use_function_id = instance_service_function.object_id @@ -1004,6 +1010,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._app = app # needed for CLI right now obj._obj = None obj._is_generator = is_generator + obj._is_web_endpoint = bool(webhook_config) obj._cluster_size = cluster_size obj._is_method = False obj._spec = function_spec # needed for modal shell @@ -1255,6 +1262,7 @@ def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata) self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR + self._is_web_endpoint = bool(metadata.web_url) self._web_url = metadata.web_url self._function_name = metadata.function_name self._is_method = metadata.is_method From 6dc8a727e957404b288742f44c40cb40785748d9 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 23:26:28 +0000 Subject: [PATCH 51/96] dont modify _add_function --- modal/app.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modal/app.py b/modal/app.py index da6576a56..3da26bb4f 100644 --- a/modal/app.py +++ b/modal/app.py @@ -509,10 +509,6 @@ def _add_function(self, function: _Function): self._add_object(function.tag, function) if function._is_web_endpoint: self._web_endpoints.append(function.tag) - if function._method_functions: - for method_function in function._method_functions.values(): - if method_function._is_web_endpoint: - self._web_endpoints.append(method_function.tag) def _init_container(self, client: _Client, running_app: RunningApp): self._app_id = running_app.app_id From c260a93932fb8c54f1b1fbd20a7be6698624d287 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 23:27:58 +0000 Subject: [PATCH 52/96] add web endpoints for methods --- modal/app.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modal/app.py b/modal/app.py index 3da26bb4f..da6576a56 100644 --- a/modal/app.py +++ b/modal/app.py @@ -509,6 +509,10 @@ def _add_function(self, function: _Function): self._add_object(function.tag, function) if function._is_web_endpoint: self._web_endpoints.append(function.tag) + if function._method_functions: + for method_function in function._method_functions.values(): + if method_function._is_web_endpoint: + self._web_endpoints.append(method_function.tag) def _init_container(self, client: _Client, running_app: RunningApp): self._app_id = running_app.app_id From 3886d513ff35bb17af96c87e22513da6227cf6ea Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Fri, 15 Nov 2024 23:50:20 +0000 Subject: [PATCH 53/96] fix bug in mock FunctionPrecreate --- test/cls_test.py | 2 +- test/conftest.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/cls_test.py b/test/cls_test.py index aedc6d11f..323d298c8 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -510,7 +510,7 @@ def bar(self): def test_method_args(servicer, client): with app_method_args.run(client=client): funcs = servicer.app_functions.values() - assert {f.function_name for f in funcs} == {"XYZ.*", "XYZ.foo", "XYZ.bar"} + assert {f.function_name for f in funcs} == {"XYZ.*"} warm_pools = {f.function_name: f.warm_pool_size for f in funcs} assert warm_pools["XYZ.*"] == 5 del warm_pools["XYZ.*"] diff --git a/test/conftest.py b/test/conftest.py index ae3eb13b3..6eea63997 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -905,7 +905,11 @@ async def FunctionPrecreate(self, stream): # This loop is for class service functions, where req.method_definitions will be non-empty method_handle_metadata: dict[str, api_pb2.FunctionHandleMetadata] = {} for method_name, method_definition in req.method_definitions.items(): - method_web_url = f"https://{method_name}.internal" + method_web_url = ( + f"https://{method_name}.internal" + if method_definition.HasField("webhook_config") and method_definition.webhook_config.type + else None + ) method_handle_metadata[method_name] = api_pb2.FunctionHandleMetadata( function_name=method_definition.function_name, function_type=method_definition.function_type, From d003af165dad0651ee5418babb4ec3f9d1e4a568 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 17 Nov 2024 20:11:09 +0000 Subject: [PATCH 54/96] fix method_functions lookup for build function --- modal/_container_entrypoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index ab7c81f42..723913127 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -559,7 +559,7 @@ def import_single_function_service( # The cls decorator is in global scope _cls = synchronizer._translate_in(cls) user_defined_callable = _cls._callables[fun_name] - function = _cls._method_functions.get(fun_name) + function = _cls._class_service_function._method_functions.get(fun_name) active_app = _cls._app else: # This is a raw class From 7f6d47551e872a7a7d6870cf19db56a0d2a3de4e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sun, 17 Nov 2024 22:45:27 +0000 Subject: [PATCH 55/96] fix tests --- modal/cls.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modal/cls.py b/modal/cls.py index d63294ee4..7c9b54b1c 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -102,7 +102,7 @@ def __init__( check_valid_cls_constructor_arg(key, kwarg) self._method_functions = {} - if class_service_function._is_hydrated: + if not getattr(class_service_function, "_fake", None): # >= v0.63 classes # first create the singular object function used by all methods on this parameterization self._instance_service_function = class_service_function._bind_parameters( @@ -285,6 +285,7 @@ async def _load(): rep = "Function(class_function)" self._class_service_function = _Function._from_loader(_load, rep) + self._class_service_function._fake = True for method in metadata.methods: self._class_service_function._method_functions[method.function_name] = _Function._new_hydrated( method.function_id, self._client, method.function_handle_metadata From d6640557e3605bf41174842634bf24f964b15823 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 00:35:06 +0000 Subject: [PATCH 56/96] more fixes --- modal/functions.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 3e50bd52c..3ebce8f27 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -433,17 +433,15 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate_function_and_method_functions( - self, function_id: str, client: _Client, handle_metadata: api_pb2.FunctionHandleMetadata - ): - self._hydrate(function_id, client, handle_metadata) + def _hydrate(self, function_id: str, client: _Client, handle_metadata: api_pb2.FunctionHandleMetadata): + super()._hydrate(function_id, client, handle_metadata) if self._method_functions: # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): if method_name in self._method_functions: method_function = self._method_functions[method_name] - method_function._hydrate(function_id, client, method_handle_metadata) + super(_Function, method_function)._hydrate(function_id, client, method_handle_metadata) elif len(handle_metadata.method_handle_metadata): # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based # on the existence of method_handle_metadata that this is a class service function @@ -713,7 +711,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti elif webhook_config: req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub @@ -911,7 +909,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona local_mounts = set(m for m in all_mounts if m.is_local()) # needed for modal.serve file watching local_mounts |= image._used_local_mounts obj._used_local_mounts = frozenset(local_mounts) - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) @@ -1089,7 +1087,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: else: raise - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) rep = f"Ref({tag})" return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True) From 1e62bd907701ede273a4625dae17d8888be794df Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 00:39:59 +0000 Subject: [PATCH 57/96] fix issue --- modal/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index 3ebce8f27..dd0a7590e 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -433,7 +433,8 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate(self, function_id: str, client: _Client, handle_metadata: api_pb2.FunctionHandleMetadata): + def _hydrate(self, function_id: str, client: _Client, handle_metadata: Optional[Message]): + assert isinstance(handle_metadata, api_pb2.FunctionHandleMetadata) super()._hydrate(function_id, client, handle_metadata) if self._method_functions: # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a From 4080e578ee4fe27b716689ed41845e9841731ef9 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 00:44:44 +0000 Subject: [PATCH 58/96] override properly --- modal/functions.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index dd0a7590e..6a26aadc0 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -433,24 +433,22 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate(self, function_id: str, client: _Client, handle_metadata: Optional[Message]): + def _hydrate(self, object_id: str, client: _Client, handle_metadata: Optional[Message]): assert isinstance(handle_metadata, api_pb2.FunctionHandleMetadata) - super()._hydrate(function_id, client, handle_metadata) + super()._hydrate(object_id, client, handle_metadata) if self._method_functions: # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): if method_name in self._method_functions: method_function = self._method_functions[method_name] - super(_Function, method_function)._hydrate(function_id, client, method_handle_metadata) + super(_Function, method_function)._hydrate(object_id, client, method_handle_metadata) elif len(handle_metadata.method_handle_metadata): # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based # on the existence of method_handle_metadata that this is a class service function self._method_functions = {} for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): - self._method_functions[method_name] = _Function._new_hydrated( - function_id, client, method_handle_metadata - ) + self._method_functions[method_name] = _Function._new_hydrated(object_id, client, method_handle_metadata) @staticmethod def from_args( From 22eb509786509624dc7b8849ac6abc7bdbd411ea Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 00:49:26 +0000 Subject: [PATCH 59/96] fix parameter name --- modal/functions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 6a26aadc0..10eae2c3b 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -433,21 +433,21 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate(self, object_id: str, client: _Client, handle_metadata: Optional[Message]): - assert isinstance(handle_metadata, api_pb2.FunctionHandleMetadata) - super()._hydrate(object_id, client, handle_metadata) + def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]): + assert isinstance(metadata, api_pb2.FunctionHandleMetadata) + super()._hydrate(object_id, client, metadata) if self._method_functions: # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects - for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): + for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): if method_name in self._method_functions: method_function = self._method_functions[method_name] super(_Function, method_function)._hydrate(object_id, client, method_handle_metadata) - elif len(handle_metadata.method_handle_metadata): + elif len(metadata.method_handle_metadata): # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based # on the existence of method_handle_metadata that this is a class service function self._method_functions = {} - for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): + for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): self._method_functions[method_name] = _Function._new_hydrated(object_id, client, method_handle_metadata) @staticmethod From 544f5bcc159e8484fcf08b02a351c63654cfde02 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 02:39:20 +0000 Subject: [PATCH 60/96] more fixes --- test/conftest.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 6eea63997..3a21b8f1a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -341,26 +341,24 @@ def get_function_metadata(self, object_id: str) -> api_pb2.FunctionHandleMetadat is_method=definition.is_method, use_method_name=definition.use_method_name, use_function_id=definition.use_function_id, - ) - - def get_class_metadata(self, object_id: str) -> api_pb2.ClassHandleMetadata: - class_handle_metadata = api_pb2.ClassHandleMetadata() - for f_name, f_id in self.classes[object_id].items(): - function_handle_metadata = self.get_function_metadata(f_id) - class_handle_metadata.methods.append( - api_pb2.ClassMethod( - function_name=f_name, function_id=f_id, function_handle_metadata=function_handle_metadata + method_handle_metadata={ + method_name: api_pb2.FunctionHandleMetadata( + function_name=method_definition.function_name, + function_type=method_definition.function_type, + web_url=method_definition.web_url, + is_method=True, + use_method_name=method_name, ) - ) - - return class_handle_metadata + for method_name, method_definition in definition.method_definitions.items() + }, + ) def get_object_metadata(self, object_id) -> api_pb2.Object: if object_id.startswith("fu-"): res = api_pb2.Object(function_handle_metadata=self.get_function_metadata(object_id)) elif object_id.startswith("cs-"): - res = api_pb2.Object(class_handle_metadata=self.get_class_metadata(object_id)) + res = api_pb2.Object(class_handle_metadata=api_pb2.ClassHandleMetadata()) elif object_id.startswith("mo-"): mount_handle_metadata = api_pb2.MountHandleMetadata(content_checksum_sha256_hex="abc123") @@ -639,7 +637,7 @@ async def ClassCreate(self, stream): class_id = "cs-" + str(len(self.classes)) self.classes[class_id] = methods await stream.send_message( - api_pb2.ClassCreateResponse(class_id=class_id, handle_metadata=self.get_class_metadata(class_id)) + api_pb2.ClassCreateResponse(class_id=class_id, handle_metadata=api_pb2.ClassHandleMetadata()) ) async def ClassGet(self, stream): @@ -650,7 +648,7 @@ async def ClassGet(self, stream): if object_id is None: raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}") await stream.send_message( - api_pb2.ClassGetResponse(class_id=object_id, handle_metadata=self.get_class_metadata(object_id)) + api_pb2.ClassGetResponse(class_id=object_id, handle_metadata=api_pb2.ClassHandleMetadata()) ) ### Client From 5d50cd1cc285619ec498c1824d19da1f3af07254 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 02:53:53 +0000 Subject: [PATCH 61/96] more test fixes --- test/cls_test.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/test/cls_test.py b/test/cls_test.py index 323d298c8..68a2d8a2a 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -57,7 +57,7 @@ def test_run_class(client, servicer): app_id = app.app_id objects = servicer.app_objects[app_id] - assert len(objects) == 3 # the class + two functions (one method-bound and one for the class) + assert len(objects) == 2 # the class + the class service function assert objects["Foo.bar"] == method_id assert objects["Foo"] == class_id class_function_id = objects["Foo.*"] @@ -72,9 +72,9 @@ def test_run_class(client, servicer): def test_call_class_sync(client, servicer): with servicer.intercept() as ctx: with app.run(client=client): - assert len(ctx.get_requests("FunctionCreate")) == 2 # one for base function, one for the method + assert len(ctx.get_requests("FunctionCreate")) == 1 # one for the class service function foo: Foo = Foo() - assert len(ctx.get_requests("FunctionCreate")) == 2 # no additional creates for an instance + assert len(ctx.get_requests("FunctionCreate")) == 1 # no additional creates for an instance ret: float = foo.bar.remote(42) assert ret == 1764 @@ -82,10 +82,10 @@ def test_call_class_sync(client, servicer): len(ctx.get_requests("FunctionBindParams")) == 0 ) # shouldn't need to bind in case there are no instance args etc. function_creates_requests: typing.List[api_pb2.FunctionCreateRequest] = ctx.get_requests("FunctionCreate") - assert len(function_creates_requests) == 2 + assert len(function_creates_requests) == 1 (class_create,) = ctx.get_requests("ClassCreate") function_creates = {fc.function.function_name: fc for fc in function_creates_requests} - assert function_creates.keys() == {"Foo.*", "Foo.bar"} + assert function_creates.keys() == {"Foo.*"} foobar_def = function_creates["Foo.bar"].function service_function_id = servicer.app_objects["ap-1"]["Foo.*"] assert foobar_def.is_method @@ -184,7 +184,7 @@ def bar(self, x): with app_ser.run(client=client): pass - assert servicer.n_functions == 2 + assert servicer.n_functions == 1 class_function = servicer.function_by_name("FooSer.*") assert class_function.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED user_cls = deserialize(class_function.class_serialized, client) @@ -512,9 +512,7 @@ def test_method_args(servicer, client): funcs = servicer.app_functions.values() assert {f.function_name for f in funcs} == {"XYZ.*"} warm_pools = {f.function_name: f.warm_pool_size for f in funcs} - assert warm_pools["XYZ.*"] == 5 - del warm_pools["XYZ.*"] - assert set(warm_pools.values()) == {0} # methods don't have warm pools themselves + assert warm_pools == {"XYZ.*": 5} def test_keep_warm_depr(client, set_env_client): @@ -550,25 +548,18 @@ def bar(self): ... with app.run(client=client): - assert len(servicer.app_functions) == 2 # class service function + method placeholder + assert len(servicer.app_functions) == 1 # only class service function cls_fun = servicer.function_by_name("ClsWithMethod.*") - method_placeholder_fun = servicer.function_by_name( - "ClsWithMethod.bar" - ) # there should be no containers at all for methods assert cls_fun.is_class - assert method_placeholder_fun.is_method assert cls_fun.warm_pool_size == 0 - assert method_placeholder_fun.warm_pool_size == 0 ClsWithMethod().keep_warm(2) # type: ignore # Python can't do type intersection assert cls_fun.warm_pool_size == 2 - assert method_placeholder_fun.warm_pool_size == 0 ClsWithMethod("other-instance").keep_warm(5) # type: ignore # Python can't do type intersection instance_service_function = servicer.function_by_name("ClsWithMethod.*", params=((("other-instance",), {}))) - assert len(servicer.app_functions) == 3 # + instance service function + assert len(servicer.app_functions) == 2 # + instance service function assert cls_fun.warm_pool_size == 2 - assert method_placeholder_fun.warm_pool_size == 0 assert instance_service_function.warm_pool_size == 5 @@ -628,8 +619,8 @@ def asgi(self): def test_web_cls(client): with web_app_app.run(client=client): c = WebCls() - assert c.endpoint.web_url == "http://xyz.internal" - assert c.asgi.web_url == "http://xyz.internal" + assert c.endpoint.web_url == "https://endpoint.internal" + assert c.asgi.web_url == "https://asgi.internal" handler_app = App("handler-app") From 9d2b9f2675b8e84e4f221e17ba7f981d43a65245 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 03:53:48 +0000 Subject: [PATCH 62/96] more fixes --- modal/cli/run.py | 7 +++++-- test/conftest.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/modal/cli/run.py b/modal/cli/run.py index 84457c105..87c06b876 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -136,7 +136,10 @@ def _get_clean_app_description(func_ref: str) -> str: def _get_click_command_for_function(app: App, function_tag): - function = app.indexed_objects[function_tag] + function = app.indexed_objects.get(function_tag, None) + if not function: + class_name, method_name = function_tag.rsplit(".", 1) + function = app.indexed_objects.get(f"{class_name}.*") assert isinstance(function, Function) function = typing.cast(Function, function) if function.is_generator: @@ -149,7 +152,7 @@ def _get_click_command_for_function(app: App, function_tag): class_name, method_name = function_tag.rsplit(".", 1) cls = typing.cast(Cls, app.indexed_objects[class_name]) cls_signature = _get_signature(function.info.user_cls) - fun_signature = _get_signature(function.info.raw_f, is_method=True) + fun_signature = _get_signature(getattr(cls, method_name).info.raw_f, is_method=True) signature = dict(**cls_signature, **fun_signature) # Pool all arguments # TODO(erikbern): assert there's no overlap? else: diff --git a/test/conftest.py b/test/conftest.py index 3a21b8f1a..9a24c5a6f 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -971,6 +971,16 @@ async def FunctionCreate(self, stream): use_function_id=function_defn.use_function_id or function_id, use_method_name=function_defn.use_method_name, definition_id=f"de-{self.n_functions}", + method_handle_metadata={ + method_name: api_pb2.FunctionHandleMetadata( + function_name=method_definition.function_name, + function_type=method_definition.function_type, + web_url=method_definition.web_url, + is_method=True, + use_method_name=method_name, + ) + for method_name, method_definition in function_defn.method_definitions.items() + }, ), ) ) From f90b22d552ae76dc776a3e14cad606bf4668ce04 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 04:57:53 +0000 Subject: [PATCH 63/96] Have _Function override _Object _hydrate method --- modal/functions.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index 5d06708fa..6e4e88237 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -523,25 +523,22 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate_function_and_method_functions( - self, function_id: str, client: _Client, handle_metadata: api_pb2.FunctionHandleMetadata - ): - self._hydrate(function_id, client, handle_metadata) + def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]): + assert isinstance(metadata, api_pb2.FunctionHandleMetadata) + super()._hydrate(object_id, client, metadata) if self._method_functions: # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects - for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): + for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): if method_name in self._method_functions: method_function = self._method_functions[method_name] - method_function._hydrate(function_id, client, method_handle_metadata) - elif len(handle_metadata.method_handle_metadata): + super(_Function, method_function)._hydrate(object_id, client, method_handle_metadata) + elif len(metadata.method_handle_metadata): # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based # on the existence of method_handle_metadata that this is a class service function self._method_functions = {} - for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): - self._method_functions[method_name] = _Function._new_hydrated( - function_id, client, method_handle_metadata - ) + for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): + self._method_functions[method_name] = _Function._new_hydrated(object_id, client, method_handle_metadata) @staticmethod def from_args( @@ -803,7 +800,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti elif webhook_config: req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub @@ -999,7 +996,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona local_mounts = set(m for m in all_mounts if m.is_local()) # needed for modal.serve file watching local_mounts |= image._used_local_mounts obj._used_local_mounts = frozenset(local_mounts) - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) @@ -1177,7 +1174,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: else: raise - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) rep = f"Ref({app_name})" return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True) From d1e7d8a119095c16ed7284fa34944a5677b3e671 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:08:32 +0000 Subject: [PATCH 64/96] corrections --- modal/_container_entrypoint.py | 2 +- modal/cls.py | 75 +++++++++++++++++++--------------- modal/functions.py | 17 ++------ 3 files changed, 46 insertions(+), 48 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 723913127..ab7c81f42 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -559,7 +559,7 @@ def import_single_function_service( # The cls decorator is in global scope _cls = synchronizer._translate_in(cls) user_defined_callable = _cls._callables[fun_name] - function = _cls._class_service_function._method_functions.get(fun_name) + function = _cls._method_functions.get(fun_name) active_app = _cls._app else: # This is a raw class diff --git a/modal/cls.py b/modal/cls.py index 7c9b54b1c..7040cf6ce 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -90,7 +90,8 @@ def _uses_common_service_function(self): def __init__( self, user_cls: type, - class_service_function: Optional[_Function], # only hydrated for = v0.63 classes # first create the singular object function used by all methods on this parameterization self._instance_service_function = class_service_function._bind_parameters( self, from_other_workspace, options, args, kwargs ) - for method_name, class_bound_method in class_service_function._method_functions.items(): + for method_name, class_bound_method in classbound_methods.items(): method = self._instance_service_function._bind_instance_method(class_bound_method) self._method_functions[method_name] = method else: # =v0.63 + _method_functions: Dict[str, _Function] = None # Placeholder _Functions for each method _options: Optional[api_pb2.FunctionOptions] _callables: Dict[str, Callable[..., Any]] _from_other_workspace: Optional[bool] # Functions require FunctionBindParams before invocation. @@ -251,6 +253,7 @@ def _initialize_from_empty(self): def _initialize_from_other(self, other: "_Cls"): self._user_cls = other._user_cls self._class_service_function = other._class_service_function + self._method_functions = other._method_functions self._options = other._options self._callables = other._callables self._from_other_workspace = other._from_other_workspace @@ -262,34 +265,33 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) - if self._class_service_function: - if self._class_service_function._method_functions: - # The class only has a class service service function and no method placeholders. - return - else: - for method in metadata.methods: - if method.function_name in self._class_service_function._method_functions: - # This happens when the class is loaded locally - # since each function will already be a loaded dependency _Function - self._class_service_function._method_functions[method.function_name]._hydrate( - method.function_id, self._client, method.function_handle_metadata - ) - else: - self._class_service_function._method_functions[method.function_name] = _Function._new_hydrated( - method.function_id, self._client, method.function_handle_metadata - ) + if self._class_service_function and self._class_service_function._method_handle_metadata: + # The class only has a class service service function and no method placeholders. + for method_metadata in self._class_service_function._method_handle_metadata.values(): + if method_metadata.function_name in self._method_functions: + # This happens when the class is loaded locally + # since each function will already be a loaded dependency _Function + self._method_functions[method_metadata.function_name]._hydrate( + self._class_service_function.object_id, self._client, method_metadata + ) + else: + self._method_functions[method_metadata.function_name] = _Function._new_hydrated( + self._class_service_function.object_id, self._client, method_metadata + ) else: - # We are dealing with a pre 0.63 class that does not have a class service function and only method functions - async def _load(): - pass - - rep = "Function(class_function)" - self._class_service_function = _Function._from_loader(_load, rep) - self._class_service_function._fake = True + # Either a class with class service function and method placeholders or pre 0.63 class that does not have a + # class service function and only method functions for method in metadata.methods: - self._class_service_function._method_functions[method.function_name] = _Function._new_hydrated( - method.function_id, self._client, method.function_handle_metadata - ) + if method.function_name in self._method_functions: + # This happens when the class is loaded locally + # since each function will already be a loaded dependency _Function + self._method_functions[method.function_name]._hydrate( + method.function_id, self._client, method.function_handle_metadata + ) + else: + self._method_functions[method.function_name] = _Function._new_hydrated( + method.function_id, self._client, method.function_handle_metadata + ) def _get_metadata(self) -> api_pb2.ClassHandleMetadata: class_handle_metadata = api_pb2.ClassHandleMetadata() @@ -326,12 +328,17 @@ def from_local(user_cls, app: "modal.app._App", class_service_function: _Functio # validate signature _Cls.validate_construction_mechanism(user_cls) + method_functions: Dict[str, _Function] = {} partial_functions: Dict[str, _PartialFunction] = _find_partial_methods_for_user_cls( user_cls, _PartialFunctionFlags.FUNCTION ) - for partial_function in partial_functions.values(): + for method_name, partial_function in partial_functions.items(): + method_function = class_service_function._bind_method(user_cls, method_name, partial_function) + if partial_function.webhook_config is not None: + app._web_endpoints.append(method_function.tag) partial_function.wrapped = True + method_functions[method_name] = method_function # Disable the warning that these are not wrapped for partial_function in _find_partial_methods_for_user_cls(user_cls, ~_PartialFunctionFlags.FUNCTION).values(): @@ -357,6 +364,7 @@ async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[s cls._app = app cls._user_cls = user_cls cls._class_service_function = class_service_function + cls._method_functions = method_functions cls._callables = callables cls._from_other_workspace = False return cls @@ -517,6 +525,7 @@ def __call__(self, *args, **kwargs) -> _Obj: return _Obj( self._user_cls, self._class_service_function, + self._method_functions, self._from_other_workspace, self._options, args, @@ -525,8 +534,8 @@ def __call__(self, *args, **kwargs) -> _Obj: def __getattr__(self, k): # Used by CLI and container entrypoint - if k in self._class_service_function._method_functions: - return self._class_service_function._method_functions[k] + if k in self._method_functions: + return self._method_functions[k] return getattr(self._user_cls, k) diff --git a/modal/functions.py b/modal/functions.py index 10eae2c3b..e0f61913e 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -331,6 +331,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _parent: Optional["_Function"] = None _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None + _method_handle_metadata: Optional[Dict[str, "api_pb2.FunctionHandleMetadata"]] = None _method_functions: Optional[Dict[str, "_Function"]] = None # Placeholder _Functions for each method def _bind_method( @@ -1177,6 +1178,7 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info self._definition_id = metadata.definition_id + self._method_handle_metadata = metadata.method_handle_metadata def _get_metadata(self): # Overridden concrete implementation of base class method @@ -1189,20 +1191,7 @@ def _get_metadata(self): is_method=self._is_method, class_parameter_info=self._class_parameter_info, definition_id=self._definition_id, - method_handle_metadata={ - method_name: api_pb2.FunctionHandleMetadata( - function_name=method_function._function_name, - function_type=get_function_type(method_function._is_generator), - web_url=method_function._web_url or "", - is_method=method_function._is_method, - definition_id=method_function._definition_id, - use_method_name=method_function._use_method_name, - ) - for method_name, method_function in self._method_functions.items() - if method_function._function_name - } - if self._method_functions - else None, + method_handle_metadata=self._method_handle_metadata, ) def _check_no_web_url(self, fn_name: str): From 8a7c9f21a36e3a8992fb2739f0de985daee9e9c2 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:10:33 +0000 Subject: [PATCH 65/96] more fixes --- modal/app.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/modal/app.py b/modal/app.py index da6576a56..125540510 100644 --- a/modal/app.py +++ b/modal/app.py @@ -492,7 +492,7 @@ def _get_watch_mounts(self): return [m for m in all_mounts if m.is_local()] - def _add_function(self, function: _Function): + def _add_function(self, function: _Function, is_web_endpoint: bool): if function.tag in self._indexed_objects: old_function = self._indexed_objects[function.tag] if isinstance(old_function, _Function): @@ -507,12 +507,8 @@ def _add_function(self, function: _Function): logger.warning(f"Warning: tag {function.tag} exists but is overridden by function") self._add_object(function.tag, function) - if function._is_web_endpoint: + if is_web_endpoint: self._web_endpoints.append(function.tag) - if function._method_functions: - for method_function in function._method_functions.values(): - if method_function._is_web_endpoint: - self._web_endpoints.append(method_function.tag) def _init_container(self, client: _Client, running_app: RunningApp): self._app_id = running_app.app_id @@ -819,7 +815,7 @@ def f(self, x): cluster_size=cluster_size, # Experimental: Clustered functions ) - self._add_function(function) + self._add_function(function, webhook_config is not None) return function @@ -948,7 +944,7 @@ def wrapper(user_cls: CLS_T) -> CLS_T: _experimental_custom_scaling_factor=_experimental_custom_scaling_factor, ) - self._add_function(cls_func) + self._add_function(cls_func, is_web_endpoint=False) cls: _Cls = _Cls.from_local(user_cls, self, cls_func) From 7685d636d4c67d44b02418eee5ee82d05b8521bd Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:13:53 +0000 Subject: [PATCH 66/96] cleanup --- modal/functions.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index e0f61913e..f8d9c23c7 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -363,7 +363,6 @@ def _bind_method( fun._use_method_name = method_name fun._app = class_service_function._app fun._is_generator = partial_function.is_generator - fun._is_web_endpoint = partial_function.webhook_config is not None fun._cluster_size = partial_function.cluster_size fun._spec = class_service_function._spec fun._is_method = True @@ -394,7 +393,6 @@ def hydrate_from_instance_service_function(method_placeholder_fun): ) # TODO: this shouldn't be set when actual parameters are used method_placeholder_fun._function_name = full_function_name method_placeholder_fun._is_generator = class_bound_method._is_generator - method_placeholder_fun._is_web_endpoint = class_bound_method._is_web_endpoint method_placeholder_fun._cluster_size = class_bound_method._cluster_size method_placeholder_fun._use_method_name = method_name method_placeholder_fun._use_function_id = instance_service_function.object_id @@ -434,23 +432,6 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]): - assert isinstance(metadata, api_pb2.FunctionHandleMetadata) - super()._hydrate(object_id, client, metadata) - if self._method_functions: - # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a - # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects - for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): - if method_name in self._method_functions: - method_function = self._method_functions[method_name] - super(_Function, method_function)._hydrate(object_id, client, method_handle_metadata) - elif len(metadata.method_handle_metadata): - # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based - # on the existence of method_handle_metadata that this is a class service function - self._method_functions = {} - for method_name, method_handle_metadata in metadata.method_handle_metadata.items(): - self._method_functions[method_name] = _Function._new_hydrated(object_id, client, method_handle_metadata) - @staticmethod def from_args( info: FunctionInfo, @@ -920,7 +901,6 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._app = app # needed for CLI right now obj._obj = None obj._is_generator = is_generator - obj._is_web_endpoint = bool(webhook_config) obj._cluster_size = cluster_size obj._is_method = False obj._spec = function_spec # needed for modal shell @@ -1171,7 +1151,6 @@ def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata) self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR - self._is_web_endpoint = bool(metadata.web_url) self._web_url = metadata.web_url self._function_name = metadata.function_name self._is_method = metadata.is_method From 4efafa15af5da5a797ef49f67e828d347367d489 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:15:14 +0000 Subject: [PATCH 67/96] more cleanup --- modal/functions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index f8d9c23c7..cf73084f3 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -317,9 +317,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _build_args: dict _can_use_base_function: bool = False # whether we need to call FunctionBindParams _is_generator: Optional[bool] = None - _is_web_endpoint: Optional[ - bool - ] = None # used to determine whether an un-hydrated function is a web endpoint, since web_url won't be populated _cluster_size: Optional[int] = None # when this is the method of a class/object function, invocation of this function From 8400fb1f6cff48be1b04bd6b18129b7b089630fb Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:30:55 +0000 Subject: [PATCH 68/96] cleanup --- modal/cli/run.py | 2 -- modal/cls.py | 42 ++++++++++++++++++++++-------------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/modal/cli/run.py b/modal/cli/run.py index 87c06b876..42c420d10 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -147,9 +147,7 @@ def _get_click_command_for_function(app: App, function_tag): signature: Dict[str, ParameterMetadata] cls: Optional[Cls] = None - method_name: Optional[str] = None if function.info.user_cls is not None: - class_name, method_name = function_tag.rsplit(".", 1) cls = typing.cast(Cls, app.indexed_objects[class_name]) cls_signature = _get_signature(function.info.user_cls) fun_signature = _get_signature(getattr(cls, method_name).info.raw_f, is_method=True) diff --git a/modal/cls.py b/modal/cls.py index 7040cf6ce..b8aa54a69 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -265,33 +265,35 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) - if self._class_service_function and self._class_service_function._method_handle_metadata: + if self._class_service_function and len(self._class_service_function._method_handle_metadata): # The class only has a class service service function and no method placeholders. - for method_metadata in self._class_service_function._method_handle_metadata.values(): - if method_metadata.function_name in self._method_functions: - # This happens when the class is loaded locally - # since each function will already be a loaded dependency _Function - self._method_functions[method_metadata.function_name]._hydrate( - self._class_service_function.object_id, self._client, method_metadata + if self._method_functions: + # We're here when the Cls is loaded locally (e.g. _Cls.from_local) so the _method_functions mapping is + # populated with (un-hydrated) _Function objects + for ( + method_name, + method_handle_metadata, + ) in self._class_service_function._method_handle_metadata.values(): + self._method_functions[method_name]._hydrate( + self._class_service_function.object_id, self._client, method_handle_metadata ) - else: - self._method_functions[method_metadata.function_name] = _Function._new_hydrated( - self._class_service_function.object_id, self._client, method_metadata + else: + # We're here when the function is loaded remotely (e.g. _Cls.from_name) + self._method_functions = {} + for ( + method_name, + method_handle_metadata, + ) in self._class_service_function._method_handle_metadata.values(): + self._method_functions[method_name] = _Function._new_hydrated( + self._class_service_function.object_id, self._client, method_handle_metadata ) else: # Either a class with class service function and method placeholders or pre 0.63 class that does not have a # class service function and only method functions for method in metadata.methods: - if method.function_name in self._method_functions: - # This happens when the class is loaded locally - # since each function will already be a loaded dependency _Function - self._method_functions[method.function_name]._hydrate( - method.function_id, self._client, method.function_handle_metadata - ) - else: - self._method_functions[method.function_name] = _Function._new_hydrated( - method.function_id, self._client, method.function_handle_metadata - ) + self._method_functions[method.function_name] = _Function._new_hydrated( + method.function_id, self._client, method.function_handle_metadata + ) def _get_metadata(self) -> api_pb2.ClassHandleMetadata: class_handle_metadata = api_pb2.ClassHandleMetadata() From f4ff8cd54ebcea898383210047d47d7e5f2e7069 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:35:15 +0000 Subject: [PATCH 69/96] cleanup --- modal/cls.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index b8aa54a69..ba18d41b8 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -287,9 +287,14 @@ def _hydrate_metadata(self, metadata: Message): self._method_functions[method_name] = _Function._new_hydrated( self._class_service_function.object_id, self._client, method_handle_metadata ) + elif self._class_service_function: + # A class with a class service function and method placeholder functions + for method in metadata.methods: + self._method_functions[method.function_name] = _Function._new_hydrated( + self._class_service_function.object_id, self._client, method.function_handle_metadata + ) else: - # Either a class with class service function and method placeholders or pre 0.63 class that does not have a - # class service function and only method functions + # pre 0.63 class that does not have a class service function and only method functions for method in metadata.methods: self._method_functions[method.function_name] = _Function._new_hydrated( method.function_id, self._client, method.function_handle_metadata From 2627cb3f774f2c1c922b2c26510e6dddb43d3478 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 17:43:12 +0000 Subject: [PATCH 70/96] fix type error --- modal/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index cf73084f3..7601bcab4 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1154,7 +1154,7 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info self._definition_id = metadata.definition_id - self._method_handle_metadata = metadata.method_handle_metadata + self._method_handle_metadata = dict(metadata.method_handle_metadata) def _get_metadata(self): # Overridden concrete implementation of base class method From 85699d30b7829317f6e47cc930f75cabb5c56de7 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 18:18:41 +0000 Subject: [PATCH 71/96] values() -> items() --- modal/cls.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index ba18d41b8..cdfa5e7e8 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -273,7 +273,7 @@ def _hydrate_metadata(self, metadata: Message): for ( method_name, method_handle_metadata, - ) in self._class_service_function._method_handle_metadata.values(): + ) in self._class_service_function._method_handle_metadata.items(): self._method_functions[method_name]._hydrate( self._class_service_function.object_id, self._client, method_handle_metadata ) @@ -283,7 +283,7 @@ def _hydrate_metadata(self, metadata: Message): for ( method_name, method_handle_metadata, - ) in self._class_service_function._method_handle_metadata.values(): + ) in self._class_service_function._method_handle_metadata.items(): self._method_functions[method_name] = _Function._new_hydrated( self._class_service_function.object_id, self._client, method_handle_metadata ) From cc7496e077f8f0545a8ca8c412aea8fbfdf65792 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 18:53:22 +0000 Subject: [PATCH 72/96] Fix mock FunctionPrecreate --- test/conftest.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index ae3eb13b3..fe95a22c2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -905,7 +905,11 @@ async def FunctionPrecreate(self, stream): # This loop is for class service functions, where req.method_definitions will be non-empty method_handle_metadata: dict[str, api_pb2.FunctionHandleMetadata] = {} for method_name, method_definition in req.method_definitions.items(): - method_web_url = f"https://{method_name}.internal" + method_web_url = ( + f"http://{method_name}.internal" + if method_definition.HasField("webhook_config") and method_definition.webhook_config.type + else None + ) method_handle_metadata[method_name] = api_pb2.FunctionHandleMetadata( function_name=method_definition.function_name, function_type=method_definition.function_type, From a5c93b96277e5caf781eac27b6feaf4bb1aa2faf Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 19:05:01 +0000 Subject: [PATCH 73/96] fix test_web_cls --- test/cls_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cls_test.py b/test/cls_test.py index 68a2d8a2a..d2b76346d 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -619,8 +619,8 @@ def asgi(self): def test_web_cls(client): with web_app_app.run(client=client): c = WebCls() - assert c.endpoint.web_url == "https://endpoint.internal" - assert c.asgi.web_url == "https://asgi.internal" + assert c.endpoint.web_url == "http://endpoint.internal" + assert c.asgi.web_url == "http://asgi.internal" handler_app = App("handler-app") From 106c711139f47b95e0b90f34429f13c5676830dc Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 19:19:10 +0000 Subject: [PATCH 74/96] another fix --- test/conftest.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 39328da85..5a9dbb739 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -943,14 +943,10 @@ async def FunctionCreate(self, stream): if len(request.function_data.ranked_functions) > 0: function_data = api_pb2.FunctionData() function_data.CopyFrom(request.function_data) - if function_data.webhook_config.type: - function_data.web_url = "http://xyz.internal" else: assert request.function function = api_pb2.Function() function.CopyFrom(request.function) - if function.webhook_config.type: - function.web_url = "http://xyz.internal" assert (function is None) != (function_data is None) function_defn = function or function_data From 504be62c4c730d980a1f4228a7c4c00c199a5619 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 19:27:55 +0000 Subject: [PATCH 75/96] Bring mock FunctionCreate to parity with server --- test/conftest.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index fe95a22c2..2114788cc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -341,6 +341,16 @@ def get_function_metadata(self, object_id: str) -> api_pb2.FunctionHandleMetadat is_method=definition.is_method, use_method_name=definition.use_method_name, use_function_id=definition.use_function_id, + method_handle_metadata={ + method_name: api_pb2.FunctionHandleMetadata( + function_name=method_definition.function_name, + function_type=method_definition.function_type, + web_url=method_definition.web_url, + is_method=True, + use_method_name=method_name, + ) + for method_name, method_definition in definition.method_definitions.items() + }, ) def get_class_metadata(self, object_id: str) -> api_pb2.ClassHandleMetadata: @@ -938,30 +948,28 @@ async def FunctionCreate(self, stream): else: self.n_functions += 1 function_id = f"fu-{self.n_functions}" - function: Optional[api_pb2.Function] = None function_data: Optional[api_pb2.FunctionData] = None - if len(request.function_data.ranked_functions) > 0: function_data = api_pb2.FunctionData() function_data.CopyFrom(request.function_data) - if function_data.webhook_config.type: - function_data.web_url = "http://xyz.internal" else: assert request.function function = api_pb2.Function() function.CopyFrom(request.function) - if function.webhook_config.type: - function.web_url = "http://xyz.internal" assert (function is None) != (function_data is None) function_defn = function or function_data assert function_defn + if function_defn.webhook_config.type: + function_defn.web_url = "http://xyz.internal" + for method_name, method_definition in function_defn.method_definitions.items(): + if method_definition.webhook_config.type: + method_definition.web_url = f"http://{method_name}.internal" self.app_functions[function_id] = function_defn if function_defn.schedule: self.function2schedule[function_id] = function_defn.schedule - await stream.send_message( api_pb2.FunctionCreateResponse( function_id=function_id, @@ -973,6 +981,16 @@ async def FunctionCreate(self, stream): use_function_id=function_defn.use_function_id or function_id, use_method_name=function_defn.use_method_name, definition_id=f"de-{self.n_functions}", + method_handle_metadata={ + method_name: api_pb2.FunctionHandleMetadata( + function_name=method_definition.function_name, + function_type=method_definition.function_type, + web_url=method_definition.web_url, + is_method=True, + use_method_name=method_name, + ) + for method_name, method_definition in function_defn.method_definitions.items() + }, ), ) ) From 527d912e1d374ce4e2a6f6e18426a76dbc0a4bc6 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 19:38:25 +0000 Subject: [PATCH 76/96] Add method_handle_metadata attr to _Function --- modal/functions.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index c94dbef10..ae0f34662 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -329,6 +329,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _parent: Optional["_Function"] = None _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None + _method_handle_metadata: Optional[Dict[str, "api_pb2.FunctionHandleMetadata"]] = None _method_functions: Optional[Dict[str, "_Function"]] = None # Placeholder _Functions for each method def _bind_method( @@ -1261,6 +1262,7 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._use_function_id = metadata.use_function_id self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info + self._method_handle_metadata = dict(metadata.method_handle_metadata) self._definition_id = metadata.definition_id def _invocation_function_id(self) -> str: @@ -1278,20 +1280,7 @@ def _get_metadata(self): is_method=self._is_method, class_parameter_info=self._class_parameter_info, definition_id=self._definition_id, - method_handle_metadata={ - method_name: api_pb2.FunctionHandleMetadata( - function_name=method_function._function_name, - function_type=get_function_type(method_function._is_generator), - web_url=method_function._web_url or "", - is_method=method_function._is_method, - definition_id=method_function._definition_id, - use_method_name=method_function._use_method_name, - ) - for method_name, method_function in self._method_functions.items() - if method_function._function_name - } - if self._method_functions - else None, + method_handle_metadata=self._method_handle_metadata, ) def _check_no_web_url(self, fn_name: str): From c0df22571896e432ec524d38c4c40ad906f20890 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 19:48:16 +0000 Subject: [PATCH 77/96] Remove _method_functions from _Function --- modal/functions.py | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/modal/functions.py b/modal/functions.py index ae0f34662..3a42aaf62 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -330,7 +330,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _class_parameter_info: Optional["api_pb2.ClassParameterInfo"] = None _method_handle_metadata: Optional[Dict[str, "api_pb2.FunctionHandleMetadata"]] = None - _method_functions: Optional[Dict[str, "_Function"]] = None # Placeholder _Functions for each method def _bind_method( self, @@ -518,26 +517,6 @@ def _deps(): fun._spec = class_bound_method._spec return fun - def _hydrate_function_and_method_functions( - self, function_id: str, client: _Client, handle_metadata: api_pb2.FunctionHandleMetadata - ): - self._hydrate(function_id, client, handle_metadata) - if self._method_functions: - # We're here when the function is loaded locally (e.g. _Function.from_args) and we're dealing with a - # class service function so the _method_functions mapping is populated with (un-hydrated) _Function objects - for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): - if method_name in self._method_functions: - method_function = self._method_functions[method_name] - method_function._hydrate(function_id, client, method_handle_metadata) - elif len(handle_metadata.method_handle_metadata): - # We're here when the function is loaded remotely (e.g. _Function.from_name) and we've determined based - # on the existence of method_handle_metadata that this is a class service function - self._method_functions = {} - for method_name, method_handle_metadata in handle_metadata.method_handle_metadata.items(): - self._method_functions[method_name] = _Function._new_hydrated( - function_id, client, method_handle_metadata - ) - @staticmethod def from_args( info: FunctionInfo, @@ -798,7 +777,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti elif webhook_config: req.webhook_config.CopyFrom(webhook_config) response = await retry_transient_errors(resolver.client.stub.FunctionPrecreate, req) - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub @@ -994,7 +973,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona local_mounts = set(m for m in all_mounts if m.is_local()) # needed for modal.serve file watching local_mounts |= image._used_local_mounts obj._used_local_mounts = frozenset(local_mounts) - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) rep = f"Function({tag})" obj = _Function._from_loader(_load, rep, preload=_preload, deps=_deps) @@ -1009,12 +988,6 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona obj._is_method = False obj._spec = function_spec # needed for modal shell - if info.user_cls: - obj._method_functions = {} - for method_name, partial_function in partial_functions.items(): - method_function = obj._bind_method(info.user_cls, method_name, partial_function) - obj._method_functions[method_name] = method_function - # Used to check whether we should rebuild a modal.Image which uses `run_function`. gpus: List[GPU_T] = gpu if isinstance(gpu, list) else [gpu] obj._build_args = dict( # See get_build_def @@ -1171,7 +1144,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: else: raise - self._hydrate_function_and_method_functions(response.function_id, resolver.client, response.handle_metadata) + self._hydrate(response.function_id, resolver.client, response.handle_metadata) rep = f"Ref({app_name})" return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True) From f3a819c43e789514b32f5b7bf346fa4de63dabcd Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 19:51:48 +0000 Subject: [PATCH 78/96] remove redundant line --- modal/functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index b64fdd915..e3b1f2e24 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1148,7 +1148,6 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._class_parameter_info = metadata.class_parameter_info self._method_handle_metadata = dict(metadata.method_handle_metadata) self._definition_id = metadata.definition_id - self._method_handle_metadata = dict(metadata.method_handle_metadata) def _get_metadata(self): # Overridden concrete implementation of base class method From 49acb6cd93df0329d298172182640768f2da3ca2 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 22:26:26 +0000 Subject: [PATCH 79/96] fix modal shell maybe --- modal/cli/import_refs.py | 14 ++++++++++---- modal/cli/run.py | 6 +++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/modal/cli/import_refs.py b/modal/cli/import_refs.py index 05f92fbaf..373b5b4b5 100644 --- a/modal/cli/import_refs.py +++ b/modal/cli/import_refs.py @@ -137,9 +137,13 @@ def get_by_object_path_try_possible_app_names(obj: Any, obj_path: Optional[str]) def _infer_function_or_help( - app: App, module, accept_local_entrypoint: bool, accept_webhook: bool + app: App, module, accept_local_entrypoint: bool, accept_webhook: bool, accept_class_function: bool ) -> Union[Function, LocalEntrypoint]: - function_choices = set(tag for tag, func in app.registered_functions.items() if not func.info.is_service_class()) + function_choices = set( + tag + for tag, func in app.registered_functions.items() + if accept_class_function or not func.info.is_service_class() + ) if not accept_webhook: function_choices -= set(app.registered_web_endpoints) if accept_local_entrypoint: @@ -253,7 +257,7 @@ def foo(): def import_function( - func_ref: str, base_cmd: str, accept_local_entrypoint=True, accept_webhook=False + func_ref: str, base_cmd: str, accept_local_entrypoint=True, accept_webhook=False, accept_class_function=False ) -> Union[Function, LocalEntrypoint]: import_ref = parse_import_ref(func_ref) @@ -267,7 +271,9 @@ def import_function( if isinstance(app_or_function, App): # infer function or display help for how to select one app = app_or_function - function_handle = _infer_function_or_help(app, module, accept_local_entrypoint, accept_webhook) + function_handle = _infer_function_or_help( + app, module, accept_local_entrypoint, accept_webhook, accept_class_function + ) return function_handle elif isinstance(app_or_function, Function): return app_or_function diff --git a/modal/cli/run.py b/modal/cli/run.py index 42c420d10..c30dfd398 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -433,7 +433,11 @@ def shell( return function = import_function( - container_or_function, accept_local_entrypoint=False, accept_webhook=True, base_cmd="modal shell" + container_or_function, + accept_local_entrypoint=False, + accept_webhook=True, + accept_class_function=True, + base_cmd="modal shell", ) assert isinstance(function, Function) function_spec: _FunctionSpec = function.spec From 28c98d1e0e0e7a77e1127fda9887fb55634aa8b9 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 23:34:43 +0000 Subject: [PATCH 80/96] fix cli tests --- modal/cli/import_refs.py | 20 ++++++++++---------- modal/cli/run.py | 17 ++++++++++++++--- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/modal/cli/import_refs.py b/modal/cli/import_refs.py index 373b5b4b5..b64d2e84b 100644 --- a/modal/cli/import_refs.py +++ b/modal/cli/import_refs.py @@ -137,13 +137,12 @@ def get_by_object_path_try_possible_app_names(obj: Any, obj_path: Optional[str]) def _infer_function_or_help( - app: App, module, accept_local_entrypoint: bool, accept_webhook: bool, accept_class_function: bool + app: App, + module, + accept_local_entrypoint: bool, + accept_webhook: bool, ) -> Union[Function, LocalEntrypoint]: - function_choices = set( - tag - for tag, func in app.registered_functions.items() - if accept_class_function or not func.info.is_service_class() - ) + function_choices = set(tag for tag in app.registered_functions.keys()) if not accept_webhook: function_choices -= set(app.registered_web_endpoints) if accept_local_entrypoint: @@ -257,7 +256,10 @@ def foo(): def import_function( - func_ref: str, base_cmd: str, accept_local_entrypoint=True, accept_webhook=False, accept_class_function=False + func_ref: str, + base_cmd: str, + accept_local_entrypoint=True, + accept_webhook=False, ) -> Union[Function, LocalEntrypoint]: import_ref = parse_import_ref(func_ref) @@ -271,9 +273,7 @@ def import_function( if isinstance(app_or_function, App): # infer function or display help for how to select one app = app_or_function - function_handle = _infer_function_or_help( - app, module, accept_local_entrypoint, accept_webhook, accept_class_function - ) + function_handle = _infer_function_or_help(app, module, accept_local_entrypoint, accept_webhook) return function_handle elif isinstance(app_or_function, Function): return app_or_function diff --git a/modal/cli/run.py b/modal/cli/run.py index c30dfd398..faeecd95a 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -137,9 +137,12 @@ def _get_clean_app_description(func_ref: str) -> str: def _get_click_command_for_function(app: App, function_tag): function = app.indexed_objects.get(function_tag, None) - if not function: + if not function or function.info.user_cls is not None: + # This is either a function_tag for a class method function (e.g MyClass.foo) or a function tag for a + # class service function (MyClass.*) class_name, method_name = function_tag.rsplit(".", 1) - function = app.indexed_objects.get(f"{class_name}.*") + if not function: + function = app.indexed_objects.get(f"{class_name}.*") assert isinstance(function, Function) function = typing.cast(Function, function) if function.is_generator: @@ -150,6 +153,15 @@ def _get_click_command_for_function(app: App, function_tag): if function.info.user_cls is not None: cls = typing.cast(Cls, app.indexed_objects[class_name]) cls_signature = _get_signature(function.info.user_cls) + if method_name == "*": + method_names = list(cls._get_partial_functions().keys()) + if len(method_names) == 1: + method_name = method_names[0] + else: + class_name = function.info.user_cls.__name__ + raise click.UsageError( + f"Please specify a specific method of {class_name} to run, e.g. `modal run foo.py::MyClass.bar`" # noqa: E501 + ) fun_signature = _get_signature(getattr(cls, method_name).info.raw_f, is_method=True) signature = dict(**cls_signature, **fun_signature) # Pool all arguments # TODO(erikbern): assert there's no overlap? @@ -436,7 +448,6 @@ def shell( container_or_function, accept_local_entrypoint=False, accept_webhook=True, - accept_class_function=True, base_cmd="modal shell", ) assert isinstance(function, Function) From 049e14871b53cc8e9d08fbf725f2e6047c97c23e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 23:36:00 +0000 Subject: [PATCH 81/96] undo some unnecessary formatting --- modal/cli/import_refs.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/modal/cli/import_refs.py b/modal/cli/import_refs.py index b64d2e84b..4b8e092fc 100644 --- a/modal/cli/import_refs.py +++ b/modal/cli/import_refs.py @@ -137,10 +137,7 @@ def get_by_object_path_try_possible_app_names(obj: Any, obj_path: Optional[str]) def _infer_function_or_help( - app: App, - module, - accept_local_entrypoint: bool, - accept_webhook: bool, + app: App, module, accept_local_entrypoint: bool, accept_webhook: bool ) -> Union[Function, LocalEntrypoint]: function_choices = set(tag for tag in app.registered_functions.keys()) if not accept_webhook: @@ -256,10 +253,7 @@ def foo(): def import_function( - func_ref: str, - base_cmd: str, - accept_local_entrypoint=True, - accept_webhook=False, + func_ref: str, base_cmd: str, accept_local_entrypoint=True, accept_webhook=False ) -> Union[Function, LocalEntrypoint]: import_ref = parse_import_ref(func_ref) From b4dbed266c1678446265c79b0a8b41b04bc7521e Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 23:37:03 +0000 Subject: [PATCH 82/96] undo more unnecessary formatting --- modal/cli/run.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/modal/cli/run.py b/modal/cli/run.py index faeecd95a..2dae012a5 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -445,10 +445,7 @@ def shell( return function = import_function( - container_or_function, - accept_local_entrypoint=False, - accept_webhook=True, - base_cmd="modal shell", + container_or_function, accept_local_entrypoint=False, accept_webhook=True, base_cmd="modal shell" ) assert isinstance(function, Function) function_spec: _FunctionSpec = function.spec From 596095045c2439853a229414e2b17d40d3a41ab9 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 23:38:45 +0000 Subject: [PATCH 83/96] fix type check --- modal/cli/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/cli/run.py b/modal/cli/run.py index 2dae012a5..9557db511 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -137,7 +137,7 @@ def _get_clean_app_description(func_ref: str) -> str: def _get_click_command_for_function(app: App, function_tag): function = app.indexed_objects.get(function_tag, None) - if not function or function.info.user_cls is not None: + if not function or (isinstance(function, Function) and function.info.user_cls is not None): # This is either a function_tag for a class method function (e.g MyClass.foo) or a function tag for a # class service function (MyClass.*) class_name, method_name = function_tag.rsplit(".", 1) From da6453cb930d606976dc3b4ae8a759f85d02fe97 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 18 Nov 2024 23:47:28 +0000 Subject: [PATCH 84/96] fix tests --- test/cls_test.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/cls_test.py b/test/cls_test.py index d2b76346d..26be7e27c 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -58,14 +58,11 @@ def test_run_class(client, servicer): objects = servicer.app_objects[app_id] assert len(objects) == 2 # the class + the class service function - assert objects["Foo.bar"] == method_id assert objects["Foo"] == class_id class_function_id = objects["Foo.*"] assert class_function_id.startswith("fu-") - assert class_function_id != method_id + assert class_function_id == method_id - assert servicer.app_functions[method_id].use_function_id == class_function_id - assert servicer.app_functions[method_id].use_method_name == "bar" assert servicer.app_functions[class_function_id].is_class @@ -86,11 +83,7 @@ def test_call_class_sync(client, servicer): (class_create,) = ctx.get_requests("ClassCreate") function_creates = {fc.function.function_name: fc for fc in function_creates_requests} assert function_creates.keys() == {"Foo.*"} - foobar_def = function_creates["Foo.bar"].function service_function_id = servicer.app_objects["ap-1"]["Foo.*"] - assert foobar_def.is_method - assert foobar_def.use_method_name == "bar" - assert foobar_def.use_function_id == service_function_id (function_map_request,) = ctx.get_requests("FunctionMap") assert function_map_request.function_id == service_function_id From f43f42642d7cac4e7b107b393996cd9145c683ad Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 19 Nov 2024 15:56:11 +0000 Subject: [PATCH 85/96] Bump minor number --- modal_version/__init__.py | 2 +- modal_version/_version_generated.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modal_version/__init__.py b/modal_version/__init__.py index f0ac348bd..2e775e5f3 100644 --- a/modal_version/__init__.py +++ b/modal_version/__init__.py @@ -7,7 +7,7 @@ major_number = 0 # Bump this manually on breaking changes, then reset the number in _version_generated.py -minor_number = 66 +minor_number = 67 # Right now, automatically increment the patch number in CI __version__ = f"{major_number}.{minor_number}.{max(build_number, 0)}" diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py index 776bda7f5..89250716b 100644 --- a/modal_version/_version_generated.py +++ b/modal_version/_version_generated.py @@ -1,4 +1,4 @@ # Copyright Modal Labs 2024 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client. -build_number = 13 # git: ed5ccfb +build_number = -1 # git: ed5ccfb From 0c6dd0325e616e00644598476538abcc2c32bd40 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Tue, 19 Nov 2024 16:21:23 +0000 Subject: [PATCH 86/96] fix bug in cls hydrate metadata --- modal/cls.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modal/cls.py b/modal/cls.py index 915ed46ce..009bb110e 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -265,7 +265,11 @@ def _get_partial_functions(self) -> Dict[str, _PartialFunction]: def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) - if self._class_service_function and len(self._class_service_function._method_handle_metadata): + if ( + self._class_service_function + and self._class_service_function._method_handle_metadata + and len(self._class_service_function._method_handle_metadata) + ): # The class only has a class service service function and no method placeholders. if self._method_functions: # We're here when the Cls is loaded locally (e.g. _Cls.from_local) so the _method_functions mapping is From ef527542c7b8ce24046aaec4d55a92ec7cd349e7 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 20 Nov 2024 20:35:22 +0000 Subject: [PATCH 87/96] dont set use_function_id on instance service function method function --- modal/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index 799d6b954..330e9bd42 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -391,7 +391,6 @@ def hydrate_from_instance_service_function(method_placeholder_fun): method_placeholder_fun._is_generator = class_bound_method._is_generator method_placeholder_fun._cluster_size = class_bound_method._cluster_size method_placeholder_fun._use_method_name = method_name - method_placeholder_fun._use_function_id = instance_service_function.object_id method_placeholder_fun._is_method = True async def _load(fun: "_Function", resolver: Resolver, existing_object_id: Optional[str]): @@ -932,6 +931,7 @@ def _bind_parameters( """ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): + print("_load for param function") if self._parent is None: raise ExecutionError("Can't find the parent class' service function") try: From 139d63ad4bcd22283a2716452f6173dd8fa9c5b7 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 20 Nov 2024 22:35:01 +0000 Subject: [PATCH 88/96] remove print statement --- modal/functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modal/functions.py b/modal/functions.py index 330e9bd42..d4069f695 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -931,7 +931,6 @@ def _bind_parameters( """ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): - print("_load for param function") if self._parent is None: raise ExecutionError("Can't find the parent class' service function") try: From 6ff73354e5b772eab06b07eb23c77042baa849e0 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 21 Nov 2024 18:44:53 +0000 Subject: [PATCH 89/96] url display changes --- modal/_utils/function_utils.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index c2d84569b..b54847542 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -519,6 +519,16 @@ async def _create_input( ) +def _get_suffix_from_web_url_info(url_info: api_pb2.WebUrlInfo) -> str: + if url_info.truncated: + suffix = " [grey70](label truncated)[/grey70]" + elif url_info.label_stolen: + suffix = " [grey70](label stolen)[/grey70]" + else: + suffix = "" + return suffix + + class FunctionCreationStatus: # TODO(michael) this really belongs with other output-related code # but moving it here so we can use it when loading a function with output disabled @@ -547,12 +557,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): elif self.response.function.web_url: url_info = self.response.function.web_url_info # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc. - if url_info.truncated: - suffix = " [grey70](label truncated)[/grey70]" - elif url_info.label_stolen: - suffix = " [grey70](label stolen)[/grey70]" - else: - suffix = "" + suffix = _get_suffix_from_web_url_info(url_info) # TODO: this is only printed when we're showing progress. Maybe move this somewhere else. web_url = self.response.handle_metadata.web_url self.status_row.finish( @@ -568,3 +573,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) else: self.status_row.finish(f"Created function {self.tag}.") + if self.response.function.method_definitions_set: + for method_definition in self.response.function.method_definitions.values(): + if method_definition.web_url: + url_info = method_definition.web_url_info + suffix = _get_suffix_from_web_url_info(url_info) + class_web_endpoint_method_status_row = self.resolver.add_status_row() + class_web_endpoint_method_status_row.finish( + f"Created web endpoint for {method_definition.function_name}: [magenta underline]" + f"{method_definition.web_url}[/magenta underline]{suffix}" + ) + self.status_row.finish(f"Created function {self.tag}.") From 46b654a305b0c9cc993261405e885457631cadc0 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 21 Nov 2024 19:55:23 +0000 Subject: [PATCH 90/96] use arrow --- modal/_utils/function_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index b54847542..592ff8f4e 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -580,7 +580,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): suffix = _get_suffix_from_web_url_info(url_info) class_web_endpoint_method_status_row = self.resolver.add_status_row() class_web_endpoint_method_status_row.finish( - f"Created web endpoint for {method_definition.function_name}: [magenta underline]" + f"Created web endpoint for {method_definition.function_name} => [magenta underline]" f"{method_definition.web_url}[/magenta underline]{suffix}" ) self.status_row.finish(f"Created function {self.tag}.") From 16fde817ef0c9777716e367c98d7f8f59e30ace8 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 21 Nov 2024 20:33:01 +0000 Subject: [PATCH 91/96] undo some changes --- modal/cls.py | 2 +- modal/functions.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modal/cls.py b/modal/cls.py index 009bb110e..d23ab3a3a 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -442,7 +442,7 @@ async def _load_remote(obj: _Object, resolver: Resolver, existing_object_id: Opt obj._hydrate(response.class_id, resolver.client, response.handle_metadata) - rep = f"Ref({tag})" + rep = f"Ref({app_name})" cls = cls._from_loader(_load_remote, rep, is_another_app=True) cls._from_other_workspace = bool(workspace is not None) return cls diff --git a/modal/functions.py b/modal/functions.py index bcaee21d8..e31622f53 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1059,7 +1059,7 @@ async def _load_remote(self: _Function, resolver: Resolver, existing_object_id: self._hydrate(response.function_id, resolver.client, response.handle_metadata) - rep = f"Ref({tag})" + rep = f"Ref({app_name})" return cls._from_loader(_load_remote, rep, is_another_app=True, hydrate_lazily=True) @staticmethod @@ -1135,7 +1135,6 @@ def _initialize_from_empty(self): self._web_url = None self._function_name = None self._info = None - self._used_local_mounts = frozenset() self._serve_mounts = frozenset() def _hydrate_metadata(self, metadata: Optional[Message]): From 7508f4d1681959f9a870f0b82faf7d458ef0d1b6 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 21 Nov 2024 20:57:43 +0000 Subject: [PATCH 92/96] Show urls for modal domains and custom domains for web endpoint methods on new-style classes without method placeholders --- modal/_utils/function_utils.py | 36 ++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index c2d84569b..86aeb3c15 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -519,6 +519,16 @@ async def _create_input( ) +def _get_suffix_from_web_url_info(url_info: api_pb2.WebUrlInfo) -> str: + if url_info.truncated: + suffix = " [grey70](label truncated)[/grey70]" + elif url_info.label_stolen: + suffix = " [grey70](label stolen)[/grey70]" + else: + suffix = "" + return suffix + + class FunctionCreationStatus: # TODO(michael) this really belongs with other output-related code # but moving it here so we can use it when loading a function with output disabled @@ -547,12 +557,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): elif self.response.function.web_url: url_info = self.response.function.web_url_info # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc. - if url_info.truncated: - suffix = " [grey70](label truncated)[/grey70]" - elif url_info.label_stolen: - suffix = " [grey70](label stolen)[/grey70]" - else: - suffix = "" + suffix = _get_suffix_from_web_url_info(url_info) # TODO: this is only printed when we're showing progress. Maybe move this somewhere else. web_url = self.response.handle_metadata.web_url self.status_row.finish( @@ -563,8 +568,23 @@ def __exit__(self, exc_type, exc_val, exc_tb): for custom_domain in self.response.function.custom_domain_info: custom_domain_status_row = self.resolver.add_status_row() custom_domain_status_row.finish( - f"Custom domain for {self.tag} => [magenta underline]" - f"{custom_domain.url}[/magenta underline]{suffix}" + f"Custom domain for {self.tag} => [magenta underline]" f"{custom_domain.url}[/magenta underline]" ) else: self.status_row.finish(f"Created function {self.tag}.") + if self.response.function.method_definitions_set: + for method_definition in self.response.function.method_definitions.values(): + if method_definition.web_url: + url_info = method_definition.web_url_info + suffix = _get_suffix_from_web_url_info(url_info) + class_web_endpoint_method_status_row = self.resolver.add_status_row() + class_web_endpoint_method_status_row.finish( + f"Created web endpoint for {method_definition.function_name} => [magenta underline]" + f"{method_definition.web_url}[/magenta underline]{suffix}" + ) + for custom_domain in method_definition.custom_domain_info: + custom_domain_status_row = self.resolver.add_status_row() + custom_domain_status_row.finish( + f"Custom domain for {method_definition.function_name} => [magenta underline]" + f"{custom_domain.url}[/magenta underline]" + ) From ce79aeafadc5a15fcf80b7b9825b089dc93b03b6 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Thu, 21 Nov 2024 23:57:39 +0000 Subject: [PATCH 93/96] bug fix --- modal/cls.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modal/cls.py b/modal/cls.py index d23ab3a3a..af60dffb3 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -293,12 +293,14 @@ def _hydrate_metadata(self, metadata: Message): ) elif self._class_service_function: # A class with a class service function and method placeholder functions + self._method_functions = {} for method in metadata.methods: self._method_functions[method.function_name] = _Function._new_hydrated( self._class_service_function.object_id, self._client, method.function_handle_metadata ) else: # pre 0.63 class that does not have a class service function and only method functions + self._method_functions = {} for method in metadata.methods: self._method_functions[method.function_name] = _Function._new_hydrated( method.function_id, self._client, method.function_handle_metadata From fa9554cdcaba417d98c8bbfe4b36dda7d9ce7773 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sat, 23 Nov 2024 01:48:01 +0000 Subject: [PATCH 94/96] add test coverage --- test/cls_test.py | 17 +++++++++++++---- test/conftest.py | 13 ++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/test/cls_test.py b/test/cls_test.py index 794545ce0..3c76ecca2 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -49,21 +49,30 @@ def bar(self, x: int) -> float: def test_run_class(client, servicer): + assert len(servicer.precreated_functions) == 0 assert servicer.n_functions == 0 with app.run(client=client): - method_id = Foo.bar.object_id + method_handle_object_id = Foo.bar.object_id assert isinstance(Foo, Cls) class_id = Foo.object_id app_id = app.app_id + assert len(servicer.classes) == 1 and servicer.classes[0] == class_id + assert servicer.n_functions == 1 objects = servicer.app_objects[app_id] + class_function_id = objects["Foo.*"] + assert servicer.precreated_functions == {class_function_id} + assert method_handle_object_id == class_function_id assert len(objects) == 2 # the class + the class service function assert objects["Foo"] == class_id - class_function_id = objects["Foo.*"] assert class_function_id.startswith("fu-") - assert class_function_id == method_id - assert servicer.app_functions[class_function_id].is_class + assert servicer.app_functions[class_function_id].method_definitions == { + "bar": api_pb2.MethodDefinition( + function_name="Foo.bar", + function_type=api_pb2.Function.FunctionType.FUNCTION_TYPE_FUNCTION, + ) + } def test_call_class_sync(client, servicer): diff --git a/test/conftest.py b/test/conftest.py index 246d1bc9d..06801bb57 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -170,7 +170,7 @@ def __init__(self, blob_host, blobs, credentials): self.function_create_error: Optional[BaseException] = None self.heartbeat_status_code = None self.n_apps = 0 - self.classes = {} + self.classes = [] self.environments = {"main": "en-1"} self.task_result = None @@ -625,12 +625,9 @@ async def BlobGet(self, stream): async def ClassCreate(self, stream): request: api_pb2.ClassCreateRequest = await stream.recv_message() assert request.app_id - methods: dict[str, str] = {method.function_name: method.function_id for method in request.methods} class_id = "cs-" + str(len(self.classes)) - self.classes[class_id] = methods - await stream.send_message( - api_pb2.ClassCreateResponse(class_id=class_id, handle_metadata=api_pb2.ClassHandleMetadata()) - ) + self.classes.append(class_id) + await stream.send_message(api_pb2.ClassCreateResponse(class_id=class_id)) async def ClassGet(self, stream): request: api_pb2.ClassGetRequest = await stream.recv_message() @@ -639,9 +636,7 @@ async def ClassGet(self, stream): object_id = app_objects.get(request.object_tag) if object_id is None: raise GRPCError(Status.NOT_FOUND, f"can't find object {request.object_tag}") - await stream.send_message( - api_pb2.ClassGetResponse(class_id=object_id, handle_metadata=api_pb2.ClassHandleMetadata()) - ) + await stream.send_message(api_pb2.ClassGetResponse(class_id=object_id)) ### Client From 720aa559cfff44298a7549ecc2e50375b8006368 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Sat, 23 Nov 2024 01:48:52 +0000 Subject: [PATCH 95/96] edit build_number --- modal_version/_version_generated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py index ec2152064..09436fb2e 100644 --- a/modal_version/_version_generated.py +++ b/modal_version/_version_generated.py @@ -1,4 +1,4 @@ # Copyright Modal Labs 2024 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client. -build_number = 39 # git: 892d93f +build_number = -1 # git: 892d93f From 1d606b69b83edb7bb1750f341e457d741d7ecb9f Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Wed, 27 Nov 2024 20:43:31 +0000 Subject: [PATCH 96/96] address nits --- modal/cli/import_refs.py | 2 +- modal/cli/run.py | 2 +- modal/cls.py | 8 ++------ modal_proto/api.proto | 6 +++--- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/modal/cli/import_refs.py b/modal/cli/import_refs.py index bca75c622..bfc0362f4 100644 --- a/modal/cli/import_refs.py +++ b/modal/cli/import_refs.py @@ -110,7 +110,7 @@ def get_by_object_path(obj: Any, obj_path: str) -> Optional[Any]: def _infer_function_or_help( app: App, module, accept_local_entrypoint: bool, accept_webhook: bool ) -> Union[Function, LocalEntrypoint]: - function_choices = set(tag for tag in app.registered_functions.keys()) + function_choices = set(app.registered_functions) if not accept_webhook: function_choices -= set(app.registered_web_endpoints) if accept_local_entrypoint: diff --git a/modal/cli/run.py b/modal/cli/run.py index 3a9671188..254491f17 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -136,7 +136,7 @@ def _get_clean_app_description(func_ref: str) -> str: def _get_click_command_for_function(app: App, function_tag): - function = app.registered_functions.get(function_tag, None) + function = app.registered_functions.get(function_tag) if not function or (isinstance(function, Function) and function.info.user_cls is not None): # This is either a function_tag for a class method function (e.g MyClass.foo) or a function tag for a # class service function (MyClass.*) diff --git a/modal/cls.py b/modal/cls.py index 8fafc2ebe..d55702be0 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -244,7 +244,7 @@ class _Cls(_Object, type_prefix="cs"): _class_service_function: Optional[ _Function ] # The _Function serving *all* methods of the class, used for version >=v0.63 - _method_functions: Dict[str, _Function] = None # Placeholder _Functions for each method + _method_functions: Optional[Dict[str, _Function]] = None # Placeholder _Functions for each method _options: Optional[api_pb2.FunctionOptions] _callables: Dict[str, Callable[..., Any]] _from_other_workspace: Optional[bool] # Functions require FunctionBindParams before invocation. @@ -277,7 +277,7 @@ def _hydrate_metadata(self, metadata: Message): and self._class_service_function._method_handle_metadata and len(self._class_service_function._method_handle_metadata) ): - # The class only has a class service service function and no method placeholders. + # The class only has a class service service function and no method placeholders (v0.67+) if self._method_functions: # We're here when the Cls is loaded locally (e.g. _Cls.from_local) so the _method_functions mapping is # populated with (un-hydrated) _Function objects @@ -313,10 +313,6 @@ def _hydrate_metadata(self, metadata: Message): method.function_id, self._client, method.function_handle_metadata ) - def _get_metadata(self) -> api_pb2.ClassHandleMetadata: - class_handle_metadata = api_pb2.ClassHandleMetadata() - return class_handle_metadata - @staticmethod def validate_construction_mechanism(user_cls): """mdmd:hidden""" diff --git a/modal_proto/api.proto b/modal_proto/api.proto index ce1f3816b..1e36ddbb1 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -339,7 +339,7 @@ message AppGetObjectsItem { message AppGetObjectsRequest { string app_id = 1; bool include_unindexed = 2; - bool only_class_function = 3; + bool only_class_function = 3; // True starting with 0.67.x clients, which don't create method placeholder functions } message AppGetObjectsResponse { @@ -596,7 +596,7 @@ message ClassCreateRequest { string existing_class_id = 2; repeated ClassMethod methods = 3; reserved 4; // removed class_function_id - bool only_class_function = 5; + bool only_class_function = 5; // True starting with 0.67.x clients, which don't create method placeholder functions } message ClassCreateResponse { @@ -612,7 +612,7 @@ message ClassGetRequest { bool lookup_published = 8; // Lookup class on app published by another workspace string workspace_name = 9; - bool only_class_function = 10; + bool only_class_function = 10; // True starting with 0.67.x clients, which don't create method placeholder functions } message ClassGetResponse {