Skip to content

Commit

Permalink
Unit test for finding handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Dec 18, 2023
1 parent 4d2f7cf commit 6469f1e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
6 changes: 6 additions & 0 deletions modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
from .functions import (
Function,
asgi_app,
build,
current_function_call_id,
current_input_id,
enter,
exit,
method,
web_endpoint,
wsgi_app,
Expand Down Expand Up @@ -66,10 +69,13 @@
"Tunnel",
"Volume",
"asgi_app",
"build",
"container_app",
"create_package_mounts",
"current_function_call_id",
"current_input_id",
"enter",
"exit",
"forward",
"is_local",
"method",
Expand Down
6 changes: 6 additions & 0 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,6 +1714,8 @@ def wrapper(f: Union[Callable[[], Any], _PartialFunction]) -> _PartialFunction:
else:
return _PartialFunction(f, _PartialFunctionFlags.BUILD)

return wrapper


@typechecked
def _enter(_warn_parentheses_missing=None) -> Callable[[Union[Callable[[], Any], _PartialFunction]], _PartialFunction]:
Expand All @@ -1726,6 +1728,8 @@ def wrapper(f: Union[Callable[[], Any], _PartialFunction]) -> _PartialFunction:
else:
return _PartialFunction(f, _PartialFunctionFlags.ENTER)

return wrapper


# TODO(erikbern): last argument should be Optional[TracebackType]
ExitHandlerType = Callable[[Optional[Type[BaseException]], Optional[BaseException], Any], None]
Expand All @@ -1739,6 +1743,8 @@ def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _Partia
def wrapper(f: ExitHandlerType) -> _PartialFunction:
return _PartialFunction(f, _PartialFunctionFlags.EXIT)

return wrapper


method = synchronize_api(_method)
web_endpoint = synchronize_api(_web_endpoint)
Expand Down
35 changes: 34 additions & 1 deletion test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from typing_extensions import assert_type

from modal import Cls, Function, Stub, method
from modal import Cls, Dict, Function, Stub, build, enter, exit, method
from modal._serialization import deserialize
from modal.app import ContainerApp
from modal.cls import ClsMixin
from modal.exception import DeprecationError, ExecutionError
from modal.functions import _find_partial_methods, _PartialFunction, _PartialFunctionFlags
from modal.runner import deploy_stub
from modal_proto import api_pb2
from modal_test_support.base_class import BaseCls2
Expand Down Expand Up @@ -494,3 +495,35 @@ def test_keep_warm_depr():

with pytest.warns(DeprecationError, match="@method"):
stub.cls(keep_warm=2)(ClsWith2Methods)


class ClsWithHandlers:
@build()
def my_build(self):
pass

@enter()
def my_enter(self):
pass

@build()
@enter()
def my_build_and_enter(self):
pass

@exit()
def my_exit(self, exc_type, exc, traceback):
pass


def test_handlers():
pfs: Dict[str, _PartialFunction]

pfs = _find_partial_methods(ClsWithHandlers, _PartialFunctionFlags.BUILD)
assert list(pfs.keys()) == ["my_build", "my_build_and_enter"]

pfs = _find_partial_methods(ClsWithHandlers, _PartialFunctionFlags.ENTER)
assert list(pfs.keys()) == ["my_enter", "my_build_and_enter"]

pfs = _find_partial_methods(ClsWithHandlers, _PartialFunctionFlags.EXIT)
assert list(pfs.keys()) == ["my_exit"]

0 comments on commit 6469f1e

Please sign in to comment.