diff --git a/hamilton/graph_utils.py b/hamilton/graph_utils.py index 8c3695b4..a629f739 100644 --- a/hamilton/graph_utils.py +++ b/hamilton/graph_utils.py @@ -3,8 +3,8 @@ from typing import Callable, List, Tuple -def is_submodule(child: ModuleType, parent: ModuleType): - return parent.__name__ in child.__name__ +def is_submodule(child: str, parent: str): + return parent in child def find_functions(function_module: ModuleType) -> List[Tuple[str, Callable]]: @@ -18,7 +18,7 @@ def valid_fn(fn): return ( inspect.isfunction(fn) and not fn.__name__.startswith("_") - and is_submodule(inspect.getmodule(fn), function_module) + and is_submodule(fn.__module__, function_module.__name__) ) return [f for f in inspect.getmembers(function_module, predicate=valid_fn)] diff --git a/hamilton/node.py b/hamilton/node.py index 505c5102..7148b287 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -189,7 +189,7 @@ def from_fn(fn: Callable, name: str = None) -> "Node": if name is None: name = fn.__name__ sig = inspect.signature(fn) - module = inspect.getmodule(fn).__name__ + module = fn.__module__ return Node( name, sig.return_annotation,