Skip to content

Commit

Permalink
Unwrap functions when extracting global variable references
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Dec 4, 2024
1 parent 976d37e commit 7c12cf4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
6 changes: 6 additions & 0 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,13 @@ def get_cls_var_attrs(self) -> dict[str, Any]:
def get_globals(self) -> dict[str, Any]:
from .._vendor.cloudpickle import _extract_code_globals

if self.raw_f is None:
return {}

func = self.raw_f
while getattr(func, "__wrapped__"):
# Unwrap functions decorated using functools.wrapped (potentially multiple times)
func = func.__wrapped__ # type: ignore # We only get here if the attribute exists
f_globals_ref = _extract_code_globals(func.__code__)
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in func.__globals__}
return f_globals
Expand Down
22 changes: 22 additions & 0 deletions test/function_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright Modal Labs 2023
import functools
import pytest
import time

from grpclib import Status
Expand All @@ -14,6 +16,8 @@
)
from modal_proto import api_pb2

GLOBAL_VARIABLE = "whatever"


def hasarg(a):
...
Expand Down Expand Up @@ -133,3 +137,21 @@ async def test_stream_function_call_data(servicer, client):
assert 0.111 <= elapsed < 1.0

assert await gen.__anext__() == "world"


def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
return f

return wrapper


def has_global_ref():
assert GLOBAL_VARIABLE


@pytest.mark.parametrize("func", [has_global_ref, decorator(has_global_ref)])
def test_global_variable_extraction(func):
info = FunctionInfo(func)
assert info.get_globals().get("GLOBAL_VARIABLE") == GLOBAL_VARIABLE

0 comments on commit 7c12cf4

Please sign in to comment.