Skip to content

Commit

Permalink
fix: Keep track of definitions that are implicitly imported (#481)
Browse files Browse the repository at this point in the history
Fixes #480
  • Loading branch information
mark-koch authored Sep 11, 2024
1 parent 1b73032 commit a89f225
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
11 changes: 4 additions & 7 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def load(
module = imp.id.module
assert module is not None
module.check()
defs[imp.id] = module._checked_defs[imp.id]
names[alias or imp.name] = imp.id
modules.add(module)
elif isinstance(imp, GuppyModule):
Expand All @@ -123,7 +122,6 @@ def load(
defn = ModuleDef(def_id, name, None, imp._globals)
defs[def_id] = defn
names[name] = def_id
defs |= imp._checked_defs
modules.add(imp)
elif isinstance(imp, ModuleType):
mod = find_guppy_module_in_py_module(imp)
Expand All @@ -135,16 +133,15 @@ def load(
# Also include any impls that are defined by the imported modules
impls: dict[DefId, dict[str, DefId]] = {}
for module in modules:
# We need to include everything defined in the module, including stuff that
# is not directly imported, in order to lower everything into a single Hugr
defs |= module._imported_checked_defs
defs |= module._checked_defs
# We also need to include any impls that are transitively imported
all_globals = module._imported_globals | module._globals
all_checked_defs = module._imported_checked_defs | module._checked_defs
for def_id in all_globals.impls:
impls.setdefault(def_id, {})
impls[def_id] |= all_globals.impls[def_id]
defs |= {
def_id: all_checked_defs[def_id]
for def_id in all_globals.impls[def_id].values()
}
self._imported_globals |= Globals(dict(defs), names, impls, {})
self._imported_checked_defs |= defs

Expand Down
25 changes: 25 additions & 0 deletions tests/integration/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,28 @@ def test(x: int) -> int:
return implicit_mod.foo(x)

validate(module.compile())


def test_private_func(validate):
# First, define a module with a public function
# that calls an internal one
internal_module = GuppyModule("test_internal")

@guppy(internal_module)
def _internal(x: int) -> int:
return x

@guppy(internal_module)
def g(x: int) -> int:
return _internal(x)

# The test module
module = GuppyModule("test")
module.load_all(internal_module)

@guppy(module)
def f(x: int) -> int:
return g(x)

hugr = module.compile()
validate(hugr)

0 comments on commit a89f225

Please sign in to comment.