Skip to content

Commit

Permalink
Closes #250. Closes #234. Fix closuring via site package imports
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Sep 26, 2024
1 parent 6ca5641 commit 9c4de49
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 49 deletions.
3 changes: 2 additions & 1 deletion src/ell/lmp/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def model_call(
if exempt_from_tracking:
return model_call
else:
return _track(model_call, forced_dependencies=dict(tools=tools))
# XXX: Analyze decorators with AST instead.
return _track(model_call, forced_dependencies=dict(tools=tools, response_format=api_params.get("response_format", {})))
return parameterized_lm_decorator


Expand Down
49 changes: 28 additions & 21 deletions src/ell/util/closure.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,46 +148,50 @@ def _process_dependencies(func, globals_and_frees, already_closed, recursion_sta
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack, uses)

return dependencies, imports, modules

def _process_default_kwargs(func, dependencies, already_closed, recursion_stack, uses):
"""Process default keyword arguments of a function."""
"""Process default keyword arguments and annotations of a function."""
ps = inspect.signature(func).parameters
default_kwargs = collections.OrderedDict({k: v.default for k, v in ps.items() if v.default is not inspect.Parameter.empty})
for name, val in default_kwargs.items():
_process_signature_dependency(val, dependencies, already_closed, recursion_stack, uses, name)

def _process_signature_dependency(val, dependencies, already_closed, recursion_stack, uses, name : Optional[str] = None):
# Todo: Buidl general cattr like utility for unstructureing python objects with ooks that keep track of state variables.
# Todo: break up closure into types and fucntions.
for name, param in ps.items():
if param.default is not inspect.Parameter.empty:
_process_signature_dependency(param.default, dependencies, already_closed, recursion_stack, uses, name)
if param.annotation is not inspect.Parameter.empty:
_process_signature_dependency(param.annotation, dependencies, already_closed, recursion_stack, uses, f"{name}_annotation")
if func.__annotations__.get('return') is not None:
_process_signature_dependency(func.__annotations__['return'], dependencies, already_closed, recursion_stack, uses, "return_annotation")
# XXX: In order to properly analyze this we should walk the AST rather than inspexting the signature; e.g. Field is FieldInfo not Field.
# I don't care about the actual default at time of execution just the symbols required to statically reproduce the prompt.

def _process_signature_dependency(val, dependencies, already_closed, recursion_stack, uses, name: Optional[str] = None):
# Todo: Build general cattr like utility for unstructuring python objects with hooks that keep track of state variables.
# Todo: break up closure into types and functions.
# XXX: This is not exhaustive, we should determine should import on all dependencies

if (name not in FORBIDDEN_NAMES):
if name not in FORBIDDEN_NAMES:
try:
dep = None
_uses = None
if isinstance(val, (types.FunctionType, type, types.MethodType)):
if isinstance(val, (types.FunctionType, types.MethodType)):
dep, _, _uses = lexical_closure(val, already_closed=already_closed, recursion_stack=recursion_stack.copy())

elif isinstance(val, (list, tuple, set)): # Todo: Figure out recursive ypye closurex
# print(val)
elif isinstance(val, (list, tuple, set)):
for item in val:
_process_signature_dependency(item, dependencies, already_closed, recursion_stack, uses)
else:
val_class = val if isinstance(val, type) else val.__class__
try:
is_builtin = (val.__class__.__module__ == "builtins" or val.__class__.__module__ == "__builtins__" )
is_builtin = (val_class.__module__ == "builtins" or val_class.__module__ == "__builtins__")
except:
is_builtin = False

if not is_builtin:
if should_import(val.__class__.__module__):

dependencies.append(dill.source.getimport(val.__class__, alias=val.__class__.__name__))
if should_import(val_class.__module__):
dependencies.append(dill.source.getimport(val_class, alias=val_class.__name__))
else:
dep, _, _uses = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
dep, _, _uses = lexical_closure(val_class, already_closed=already_closed, recursion_stack=recursion_stack.copy())

if dep: dependencies.append(dep)
if _uses: uses.update(_uses)
except Exception as e:
_raise_error(f"Failed to capture the lexical closure of default parameter {name}", e, recursion_stack)
_raise_error(f"Failed to capture the lexical closure of parameter or annotation {name}", e, recursion_stack)


def _process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack , uses):
Expand Down Expand Up @@ -499,10 +503,13 @@ def globalvars(func, recurse=True, builtin=False):
continue #XXX: globalvars(func, False)?
nested_func = globs.get(key)
func.update(globalvars(nested_func, True, builtin))
# elif inspect.isclass(func):
# XXX: We need to get lexical closures of all the methods and attributes of the class.\
# In the future we should exhaustively walk the AST here.
else:
return {}
#NOTE: if name not in __globals__, then we skip it...
return dict((name,globs[name]) for name in func if name in globs)


# XXX: This is a mess. COuld probably be about 100 lines of code max.
#
24 changes: 1 addition & 23 deletions src/ell/util/should_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,76 +17,54 @@ def should_import(module_name: str, raise_on_error: bool = False) -> bool:
Returns:
bool: True if the module should be imported (i.e., it's a third-party module), False otherwise.
"""
print(f"Checking if module '{module_name}' should be imported")
if module_name.startswith("ell"):
print(f"Module '{module_name}' starts with 'ell', returning True")
return True
try:
try:
print(f"Attempting to find spec for module '{module_name}'")
spec = importlib.util.find_spec(module_name)
print(f"Spec for module '{module_name}': {spec}")
except ValueError as e:
print(f"ValueError occurred while finding spec for '{module_name}': {e}")
except ValueError:
return False
if spec is None:
print(f"Spec for module '{module_name}' is None, returning False")
return False

origin = spec.origin
print(f"Origin for module '{module_name}': {origin}")
if origin is None:
print(f"Origin for module '{module_name}' is None, returning False")
return False
if spec.has_location:
print(f"Module '{module_name}' has location")
origin_path = Path(origin).resolve()
print(f"Resolved origin path: {origin_path}")

site_packages = list(site.getsitepackages()) + (list(site.getusersitepackages()) if isinstance(site.getusersitepackages(), list) else [site.getusersitepackages()])
print(f"Site packages: {site_packages}")

additional_paths = [Path(p).resolve() for p in sys.path if Path(p).resolve() not in map(Path, site_packages)]
print(f"Additional paths: {additional_paths}")

project_root = Path(os.environ.get("ELL_PROJECT_ROOT", os.getcwd())).resolve()
print(f"Project root: {project_root}")

site_packages_paths = [Path(p).resolve() for p in site_packages]
stdlib_path = sysconfig.get_paths().get("stdlib")
if stdlib_path:
site_packages_paths.append(Path(stdlib_path).resolve())
print(f"Site packages paths (including stdlib): {site_packages_paths}")

additional_paths = [Path(p).resolve() for p in additional_paths]
local_paths = [project_root]
print(f"Local paths: {local_paths}")

cwd = Path.cwd().resolve()
additional_paths = [path for path in additional_paths if path != cwd]
print(f"Additional paths (excluding cwd): {additional_paths}")

for pkg in site_packages_paths:
if origin_path.is_relative_to(pkg):
print(f"Module '{module_name}' is relative to site package {pkg}, returning True")
return True

for path in additional_paths:
if origin_path.is_relative_to(path):
print(f"Module '{module_name}' is relative to additional path {path}, returning False")
return False

for local in local_paths:
if origin_path.is_relative_to(local):
print(f"Module '{module_name}' is relative to local path {local}, returning False")
return False

print(f"Module '{module_name}' doesn't match any criteria, returning True")
return True

except Exception as e:
print(f"Failed to find spec for {module_name}. Please report to https://github.com/MadcowD/ell/issues. Error: {e}")
if raise_on_error:
raise e
# raise e
return True
8 changes: 4 additions & 4 deletions tests/test_should_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def test_should_import_exception_handling(mock_project_root, mock_sysconfig_path
patch("sysconfig.get_paths", return_value=mock_sysconfig_paths), \
patch("os.environ.get", return_value=str(mock_project_root)):

result = should_import("any_module")
assert result == True, "Function should return True when an exception occurs and raise_on_error is False"
with pytest.raises(Exception) as exc_info:
should_import("any_module", raise_on_error=True)
assert "Test Exception" in str(exc_info.value)

captured = capsys.readouterr()
assert "Failed to find spec for any_module" in captured.out
assert should_import("any_module") == True, "Function should return True when an exception occurs and raise_on_error is False"

def test_should_import_raise_on_error(mock_project_root, mock_sysconfig_paths, mock_site_packages, monkeypatch):
"""
Expand Down

0 comments on commit 9c4de49

Please sign in to comment.