diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index b738d69fa..8e305c5fd 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -248,7 +248,13 @@ def get_cls_var_attrs(self) -> dict[str, Any]: def get_globals(self) -> dict[str, Any]: from .._vendor.cloudpickle import _extract_code_globals + if self.raw_f is None: + return {} + func = self.raw_f + while hasattr(func, "__wrapped__") and func is not func.__wrapped__: + # Unwrap functions decorated using functools.wrapped (potentially multiple times) + func = func.__wrapped__ f_globals_ref = _extract_code_globals(func.__code__) f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in func.__globals__} return f_globals diff --git a/test/function_utils_test.py b/test/function_utils_test.py index 17926a5f2..e64571287 100644 --- a/test/function_utils_test.py +++ b/test/function_utils_test.py @@ -1,4 +1,6 @@ # Copyright Modal Labs 2023 +import functools +import pytest import time from grpclib import Status @@ -14,6 +16,8 @@ ) from modal_proto import api_pb2 +GLOBAL_VARIABLE = "whatever" + def hasarg(a): ... @@ -133,3 +137,21 @@ async def test_stream_function_call_data(servicer, client): assert 0.111 <= elapsed < 1.0 assert await gen.__anext__() == "world" + + +def decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + return f + + return wrapper + + +def has_global_ref(): + assert GLOBAL_VARIABLE + + +@pytest.mark.parametrize("func", [has_global_ref, decorator(has_global_ref)]) +def test_global_variable_extraction(func): + info = FunctionInfo(func) + assert info.get_globals().get("GLOBAL_VARIABLE") == GLOBAL_VARIABLE