Skip to content

Commit

Permalink
fix bug with dependencies?
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Aug 4, 2024
1 parent 16cccf2 commit 6156aec
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 121 deletions.
3 changes: 2 additions & 1 deletion src/ell/decorators/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def decorator(
fn: LMP,
) -> InvocableLM:
color = compute_color(fn)
_under_fn = fn

@wraps(fn)
def wrapper(
Expand All @@ -48,7 +49,7 @@ def wrapper(

# TODO: # we'll deal with type safety here later
wrapper.__ell_lm_kwargs__ = lm_kwargs
wrapper.__ell_func__ = fn
wrapper.__ell_func__ = _under_fn
wrapper.__ell_lm = True
wrapper.__ell_exempt_from_tracking = exempt_from_tracking
if exempt_from_tracking:
Expand Down
117 changes: 98 additions & 19 deletions src/ell/util/closure.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def lexical_closure(
already_closed.add(hash(func))

globals_and_frees = _get_globals_and_frees(func)
dependencies, imports, modules = _process_dependencies(func, globals_and_frees, already_closed, recursion_stack)
dependencies, imports, modules = _process_dependencies(func, globals_and_frees, already_closed, recursion_stack, uses)

cur_src = _build_initial_source(imports, dependencies, source)

Expand All @@ -103,7 +103,7 @@ def lexical_closure(
fn_hash = _generate_function_hash(source, dsrc, func.__qualname__)

_update_ell_func(outer_ell_func, source, dsrc, globals_and_frees['globals'], globals_and_frees['frees'], fn_hash, uses)

return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))


Expand All @@ -117,7 +117,7 @@ def _format_source(source: str) -> str:

def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
"""Get global and free variables for a function."""
globals_dict = collections.OrderedDict(dill.detect.globalvars(func))
globals_dict = collections.OrderedDict(globalvars(func))
frees_dict = collections.OrderedDict(dill.detect.freevars(func))

if isinstance(func, type):
Expand All @@ -128,60 +128,71 @@ def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:

return {'globals': globals_dict, 'frees': frees_dict}

def _process_dependencies(func, globals_and_frees, already_closed, recursion_stack):
def _process_dependencies(func, globals_and_frees, already_closed, recursion_stack, uses):
"""Process function dependencies."""
dependencies = []
modules = deque()
imports = []

if isinstance(func, (types.FunctionType, types.MethodType)):
_process_default_kwargs(func, dependencies, already_closed, recursion_stack)
_process_default_kwargs(func, dependencies, already_closed, recursion_stack, uses)

for var_name, var_value in {**globals_and_frees['globals'], **globals_and_frees['frees']}.items():
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack)
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack, uses)

return dependencies, imports, modules

def _process_default_kwargs(func, dependencies, already_closed, recursion_stack):
def _process_default_kwargs(func, dependencies, already_closed, recursion_stack, uses):
"""Process default keyword arguments of a function."""
ps = inspect.signature(func).parameters
default_kwargs = collections.OrderedDict({k: v.default for k, v in ps.items() if v.default is not inspect.Parameter.empty})
for name, val in default_kwargs.items():
if name not in FORBIDDEN_NAMES:
try:
is_builtin = val.__class__.__module__ == "builtins" or val.__class__.__module__ == "__builtins__"
except:
is_builtin = False
if name not in FORBIDDEN_NAMES and not is_builtin:
try:
dep, _, _ = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
dep, _, _uses = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
dependencies.append(dep)
uses.update(_uses)
except Exception as e:
_raise_error(f"Failed to capture the lexical closure of default parameter {name}", e, recursion_stack)

def _process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack, uses):
    """Route one captured variable to the handler that matches its kind.

    Dispatch order:
      * functions, classes and methods -> ``_process_callable``
      * modules                        -> ``_process_module``
      * built-in functions             -> an import line appended to ``imports``
      * anything else                  -> ``_process_other_variable``

    Args:
        var_name: Name the variable is bound to in the inspected function.
        var_value: The bound object.
        dependencies, modules, imports: Accumulators mutated in place by the
            chosen handler.
        already_closed, recursion_stack: Forwarded to ``_process_callable``.
        uses: Collection forwarded to the handlers.
    """
    if isinstance(var_value, (types.FunctionType, type, types.MethodType)):
        _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack, uses)
    elif isinstance(var_value, types.ModuleType):
        _process_module(var_name, var_value, modules, imports, uses)
    elif isinstance(var_value, types.BuiltinFunctionType):
        # Built-ins have no retrievable source, so record how to import them.
        imports.append(dill.source.getimport(var_value, alias=var_name))
    else:
        _process_other_variable(var_name, var_value, dependencies, uses)

def _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack):
def _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack, uses):
"""Process a callable (function, method, or class)."""
if var_name not in FORBIDDEN_NAMES:
try:
module_is_ell = 'ell' in inspect.getmodule(var_value).__name__
except:
module_is_ell = False

if var_name not in FORBIDDEN_NAMES and not module_is_ell:
try:
dep, _, _ = lexical_closure(var_value, already_closed=already_closed, recursion_stack=recursion_stack.copy())
dep, _, _uses = lexical_closure(var_value, already_closed=already_closed, recursion_stack=recursion_stack.copy())
dependencies.append(dep)
uses.update(_uses)
except Exception as e:
_raise_error(f"Failed to capture the lexical closure of global or free variable {var_name}", e, recursion_stack)

def _process_module(var_name, var_value, modules, imports, uses):
    """Record a referenced module either as an import line or for later handling.

    Args:
        var_name: Name (alias) the module is bound to in the inspected function.
        var_value: The module object.
        modules: Deque/list that receives ``(var_name, var_value)`` when
            ``should_import`` rejects the module.
        imports: List that receives the generated import statement when
            ``should_import`` accepts the module.
        uses: Unused here; accepted so every ``_process_*`` helper shares the
            same trailing parameter.
    """
    if should_import(var_value):
        imports.append(dill.source.getimport(var_value, alias=var_name))
    else:
        modules.append((var_name, var_value))

def _process_other_variable(var_name, var_value, dependencies):
def _process_other_variable(var_name, var_value, dependencies, uses):
"""Process variables that are not callables or modules."""
if isinstance(var_value, str) and '\n' in var_value:
dependencies.append(f"{var_name} = '''{var_value}'''")
Expand Down Expand Up @@ -400,4 +411,72 @@ def is_function_called(func_name, source_code):
return True

# If we've gone through all the nodes and haven't found a call to the function, it's not called
return False
return False

#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Modified by: William Guss.
# Copyright (c) 2008-2016 California Institute of Technology.
# Copyright (c) 2016-2024 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
from dill.detect import nestedglobals
import inspect

def globalvars(func, recurse=True, builtin=False):
    """Return the global-scope objects that ``func`` refers to.

    Variant of ``dill.detect.globalvars`` that first unwraps ``__ell_func__``
    wrappers so the underlying function is the one analysed.

    Args:
        func: A function, bound method, code object, or an object exposing
            ``__ell_func__``. Anything else yields an empty dict.
        recurse: When true, names are collected with ``nestedglobals`` and the
            globals of every referenced function are followed as well. When
            false, only the code object's own ``co_names`` are used.
        builtin: When true, names from the ``builtins`` module are included
            in the lookup table.

    Returns:
        A dict ``{name: object}`` of the referenced names that could be
        resolved. Names that are not found in the lookup table are skipped.
    """
    return _globalvars(func, recurse, builtin, frozenset())


def _unwrap_ell(func):
    """Strip ``__ell_func__`` wrappers and bound-method wrappers from ``func``."""
    while hasattr(func, "__ell_func__"):
        func = func.__ell_func__
    if inspect.ismethod(func):
        func = func.__func__
    # The method's underlying function may itself be wrapped.
    while hasattr(func, "__ell_func__"):
        func = func.__ell_func__
    return func


def _globalvars(func, recurse, builtin, active):
    """Implementation of ``globalvars``.

    ``active`` is a frozenset with the ``id()`` of every function currently
    being analysed further up the call stack. Running into one of them again
    (self-referencing closure cell, mutual recursion, a wrapped function that
    refers to itself) returns ``{}`` instead of recursing forever; the
    outer call for that function still contributes its names.
    """
    func = _unwrap_ell(func)

    if inspect.isfunction(func):
        if id(func) in active:
            return {}
        active = active | {id(func)}

        globs = vars(inspect.getmodule(sum)).copy() if builtin else {}
        names = set()

        # get references from within closure
        for cell in func.__closure__ or ():
            try:
                contents = cell.cell_contents
            except ValueError:  # cell is empty
                continue
            found = _globalvars(contents, recurse, builtin, active)
            names.update(found)
            globs.update(found)

        # get globals
        globs.update(func.__globals__ or {})

        # get names of references
        if not recurse:
            names.update(func.__code__.co_names)
        else:
            names.update(nestedglobals(func.__code__))
            # find globals for all entries of names; references back to a
            # function already on the stack are cut off by the guard above
            for key in names.copy():
                names.update(_globalvars(globs.get(key), True, builtin, active))

    elif inspect.iscode(func):
        globs = vars(inspect.getmodule(sum)).copy() if builtin else {}
        if not recurse:
            names = func.co_names  # get names
        else:
            own_name = func.co_name  # to stop infinite recursion
            names = set(nestedglobals(func))
            # find globals for all entries of names
            for key in names.copy():
                # compare by value: identity of equal strings is not guaranteed
                if key == own_name:
                    continue
                names.update(_globalvars(globs.get(key), True, builtin, active))

    else:
        return {}

    # NOTE: if a name is not in the lookup table, then we skip it...
    return {name: globs[name] for name in names if name in globs}

37 changes: 34 additions & 3 deletions tests/test_closure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import wraps
import pytest
import math
from typing import Set, Any
Expand All @@ -9,6 +10,8 @@
get_referenced_names,
is_function_called,
)
import ell


def test_lexical_closure_simple_function():
def simple_func(x):
Expand Down Expand Up @@ -45,6 +48,7 @@ def func_with_default(x=10):

result, _, _ = lexical_closure(func_with_default)
print(result)

assert "def func_with_default(x=10):" in result

@pytest.mark.parametrize("value, expected", [
Expand All @@ -68,8 +72,8 @@ class DummyModule:

def test_get_referenced_names():
code = """
import math
result = math.sin(x) + math.cos(y)
import math
result = math.sin(x) + math.cos(y)
"""
referenced = get_referenced_names(code, "math")
print(referenced)
Expand Down Expand Up @@ -105,4 +109,31 @@ def dummy_func():

_, _, uses = lexical_closure(dummy_func, initial_call=True)
assert isinstance(uses, Set)
# You might want to add a more specific check for the content of 'uses'
# You might want to add a more specific check for the content of 'uses'


def test_lexical_closure_uses():
    """An LMP that calls another LMP records exactly that LMP's hash in its uses."""

    @ell.lm(model="gpt-4")
    def dependency_func():
        return "42"

    @ell.lm(model="gpt-4")
    def main_func():
        return dependency_func()

    main_uses = main_func.__ell_uses__

    # The recorded uses are exposed as a set ...
    assert isinstance(main_uses, set)

    # ... holding the dependency's hash and nothing else.
    assert dependency_func.__ell_hash__ in main_uses
    assert len(main_uses) == 1

    # The single recorded identifier carries the 'lmp-' prefix.
    (only_use,) = main_uses
    assert only_use.startswith('lmp-')

    # The leaf LMP calls no other LMP, so it records no uses.
    assert len(dependency_func.__ell_uses__) == 0


if __name__ == "__main__":
    # Lets this module be executed directly (without pytest) to run the uses test.
    test_lexical_closure_uses()
Loading

0 comments on commit 6156aec

Please sign in to comment.