Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decorator-based lifecycle hooks #1136

Merged
merged 14 commits into from
Dec 21, 2023
6 changes: 6 additions & 0 deletions modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
from .functions import (
Function,
asgi_app,
build,
current_function_call_id,
current_input_id,
enter,
exit,
method,
web_endpoint,
wsgi_app,
Expand Down Expand Up @@ -66,10 +69,13 @@
"Tunnel",
"Volume",
"asgi_app",
"build",
"container_app",
"create_package_mounts",
"current_function_call_id",
"current_input_id",
"enter",
"exit",
"forward",
"is_local",
"method",
Expand Down
84 changes: 53 additions & 31 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Callable, Optional, Type
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Type

from grpclib import Status

Expand All @@ -44,7 +44,13 @@
from .cls import Cls
from .config import config, logger
from .exception import InvalidError
from .functions import Function, _Function, _set_current_context_ids # type: ignore
from .functions import ( # type: ignore
Function,
_find_callables_for_obj,
_Function,
_PartialFunctionFlags,
_set_current_context_ids,
)

if TYPE_CHECKING:
from types import ModuleType
Expand Down Expand Up @@ -487,12 +493,13 @@ def call_function_sync(
):
# If this function is on a class, instantiate it and enter it
if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
if hasattr(imp_fun.obj, "__enter__"):
enter_methods: Dict[str, Callable] = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.ENTER)
for enter_method in enter_methods.values():
# Call a user-defined method
with function_io_manager.handle_user_exception():
imp_fun.obj.__enter__()
elif hasattr(imp_fun.obj, "__aenter__"):
logger.warning("Not running asynchronous enter/exit handlers with a sync function")
enter_res = enter_method()
if inspect.iscoroutine(enter_res):
logger.warning("Not running asynchronous enter/exit handlers with a sync function")

try:

Expand Down Expand Up @@ -537,9 +544,11 @@ def run_inputs(input_id: str, function_call_id: str, args: Any, kwargs: Any) ->
):
run_inputs(input_id, function_call_id, args, kwargs)
finally:
if imp_fun.obj is not None and hasattr(imp_fun.obj, "__exit__"):
with function_io_manager.handle_user_exception():
imp_fun.obj.__exit__(*sys.exc_info())
if imp_fun.obj is not None:
exit_methods: Dict[str, Callable] = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.EXIT)
for exit_method in exit_methods.values():
with function_io_manager.handle_user_exception():
exit_method(*sys.exc_info())


@wrap()
Expand All @@ -549,13 +558,13 @@ async def call_function_async(
):
# If this function is on a class, instantiate it and enter it
if imp_fun.obj is not None and not imp_fun.is_auto_snapshot:
if hasattr(imp_fun.obj, "__aenter__"):
enter_methods: Dict[str, Callable] = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.ENTER)
for enter_method in enter_methods.values():
# Call a user-defined method
async with function_io_manager.handle_user_exception.aio():
await imp_fun.obj.__aenter__()
elif hasattr(imp_fun.obj, "__enter__"):
async with function_io_manager.handle_user_exception.aio():
imp_fun.obj.__enter__()
with function_io_manager.handle_user_exception():
enter_res = enter_method()
if inspect.iscoroutine(enter_res):
await enter_res

try:

Expand Down Expand Up @@ -601,12 +610,13 @@ async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any
await run_input(input_id, function_call_id, args, kwargs)
finally:
if imp_fun.obj is not None:
if hasattr(imp_fun.obj, "__aexit__"):
async with function_io_manager.handle_user_exception.aio():
await imp_fun.obj.__aexit__(*sys.exc_info())
elif hasattr(imp_fun.obj, "__exit__"):
async with function_io_manager.handle_user_exception.aio():
imp_fun.obj.__exit__(*sys.exc_info())
exit_methods: Dict[str, Callable] = _find_callables_for_obj(imp_fun.obj, _PartialFunctionFlags.EXIT)
for exit_method in exit_methods.values():
# Call a user-defined method
with function_io_manager.handle_user_exception():
exit_res = exit_method(*sys.exc_info())
if inspect.iscoroutine(exit_res):
await exit_res


@dataclass
Expand Down Expand Up @@ -635,6 +645,8 @@ def import_function(
module: Optional[ModuleType] = None
cls: Optional[Type] = None
fun: Callable
function: Optional[_Function] = None
active_stub: Optional[_Stub] = None
pty_info: api_pb2.PTYInfo = function_def.pty_info

if pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
Expand All @@ -652,23 +664,33 @@ def import_function(

parts = qual_name.split(".")
if len(parts) == 1:
# This is a function
cls = None
fun = getattr(module, qual_name)
f = getattr(module, qual_name)
if isinstance(f, Function):
function = synchronizer._translate_in(f)
fun = function.get_raw_f()
active_stub = function._stub
else:
fun = f
elif len(parts) == 2:
# This is a method on a class
cls_name, fun_name = parts
cls = getattr(module, cls_name)
fun = getattr(cls, fun_name)
if isinstance(cls, Cls):
# The cls decorator is in global scope
_cls = synchronizer._translate_in(cls)
fun = _cls._callables[fun_name]
function = _cls._functions.get(fun_name)
active_stub = _cls._stub
else:
# This is a raw class
fun = getattr(cls, fun_name)
else:
raise InvalidError(f"Invalid function qualname {qual_name}")

# The decorator is typically in global scope, but may have been applied independently
active_stub: Optional[_Stub] = None
function: Optional[_Function] = None
if isinstance(fun, Function):
function = synchronizer._translate_in(fun)
fun = function.get_raw_f()
active_stub = function._stub
elif module is not None and not function_def.is_builder_function:
# If the cls/function decorator was applied in local scope, but the stub is global, we can look it up
if active_stub is None and function_def.stub_name:
# This branch is reached in the special case that the imported function is 1) not serialized, and 2) isn't a FunctionHandle - i.e, not decorated at definition time
# Look at all instantiated stubs - if there is only one with the indicated name, use that one
matching_stubs = _Stub._all_stubs.get(function_def.stub_name, [])
Expand Down
56 changes: 35 additions & 21 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from ._output import OutputManager
from ._resolver import Resolver
from .exception import deprecation_error
from .functions import _Function
from .functions import (
PartialFunction,
_find_callables_for_cls,
_find_partial_methods_for_cls,
_Function,
_PartialFunctionFlags,
)
from .object import _Object

T = TypeVar("T")
Expand Down Expand Up @@ -119,29 +125,31 @@ def __getattr__(self, k):
class _Cls(_Object, type_prefix="cs"):
_user_cls: Optional[type]
_functions: Dict[str, _Function]
_callables: Dict[str, Callable]

def _initialize_from_empty(self):
self._user_cls = None
self._base_functions = {}
self._functions = {}
self._callables = {}
self._output_mgr: Optional[OutputManager] = None

def _set_output_mgr(self, output_mgr: OutputManager):
self._output_mgr = output_mgr

def _hydrate_metadata(self, metadata: Message):
for method in metadata.methods:
if method.function_name in self._base_functions:
self._base_functions[method.function_name]._hydrate(
if method.function_name in self._functions:
self._functions[method.function_name]._hydrate(
method.function_id, self._client, method.function_handle_metadata
)
else:
self._base_functions[method.function_name] = _Function._new_hydrated(
self._functions[method.function_name] = _Function._new_hydrated(
method.function_id, self._client, method.function_handle_metadata
)

def _get_metadata(self) -> api_pb2.ClassHandleMetadata:
class_handle_metadata = api_pb2.ClassHandleMetadata()
for f_name, f in self._base_functions.items():
for f_name, f in self._functions.items():
class_handle_metadata.methods.append(
api_pb2.ClassMethod(
function_name=f_name, function_id=f.object_id, function_handle_metadata=f._get_metadata()
Expand All @@ -150,34 +158,40 @@ def _get_metadata(self) -> api_pb2.ClassHandleMetadata:
return class_handle_metadata

@staticmethod
def from_local(user_cls, base_functions: Dict[str, _Function]) -> "_Cls":
def from_local(user_cls, stub, decorator: Callable[[PartialFunction, type], _Function]) -> "_Cls":
functions: Dict[str, _Function] = {}
for k, partial_function in _find_partial_methods_for_cls(user_cls, _PartialFunctionFlags.FUNCTION).items():
functions[k] = decorator(partial_function, user_cls)

# Disable the warning that these are not wrapped
for partial_function in _find_partial_methods_for_cls(user_cls, ~_PartialFunctionFlags.FUNCTION).values():
partial_function.wrapped = True

# Get all callables
callables: Dict[str, Callable] = _find_callables_for_cls(user_cls, ~_PartialFunctionFlags(0))

def _deps() -> List[_Function]:
return list(base_functions.values())
return list(functions.values())

async def _load(provider: _Object, resolver: Resolver, existing_object_id: Optional[str]):
req = api_pb2.ClassCreateRequest(app_id=resolver.app_id, existing_class_id=existing_object_id)
for f_name, f in base_functions.items():
for f_name, f in functions.items():
req.methods.append(api_pb2.ClassMethod(function_name=f_name, function_id=f.object_id))
resp = await resolver.client.stub.ClassCreate(req)
provider._hydrate(resp.class_id, resolver.client, resp.handle_metadata)

rep = f"Cls({user_cls.__name__})"
cls = _Cls._from_loader(_load, rep, deps=_deps)
cls._stub = stub
cls._user_cls = user_cls
cls._base_functions = base_functions
setattr(cls._user_cls, "_modal_functions", base_functions) # Needed for PartialFunction.__get__
cls._functions = functions
cls._callables = callables
setattr(cls._user_cls, "_modal_functions", functions) # Needed for PartialFunction.__get__
return cls

def get_user_cls(self):
# Used by the container entrypoint
return self._user_cls

def get_base_function(self, k: str) -> _Function:
return self._base_functions[k]

def __call__(self, *args, **kwargs) -> _Obj:
"""This acts as the class constructor."""
return _Obj(self._user_cls, self._output_mgr, self._base_functions, args, kwargs)
return _Obj(self._user_cls, self._output_mgr, self._functions, args, kwargs)

async def remote(self, *args, **kwargs):
deprecation_error(
Expand All @@ -186,8 +200,8 @@ async def remote(self, *args, **kwargs):

def __getattr__(self, k):
# Used by CLI and container entrypoint
if k in self._base_functions:
return self._base_functions[k]
if k in self._functions:
return self._functions[k]
return getattr(self._user_cls, k)


Expand Down
Loading