Skip to content

Commit

Permalink
feat: Add extern symbols (#236)
Browse files Browse the repository at this point in the history
mark-koch authored Jun 10, 2024
1 parent 4c2b5a9 commit 977ccd8
Showing 6 changed files with 196 additions and 2 deletions.
41 changes: 40 additions & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import inspect
from collections.abc import Callable
from dataclasses import dataclass, field
@@ -7,7 +8,7 @@

from hugr.serialization import ops, tys

from guppylang.ast_util import has_empty_body
from guppylang.ast_util import annotate_location, has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.custom import (
CustomCallChecker,
@@ -18,6 +19,7 @@
RawCustomFunctionDef,
)
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.extern import RawExternDef
from guppylang.definition.function import RawFunctionDef, parse_py_func
from guppylang.definition.parameter import TypeVarDef
from guppylang.definition.struct import RawStructDef
@@ -226,6 +228,43 @@ def dec(f: Callable[..., Any]) -> RawFunctionDecl:

return dec

def extern(
self,
module: GuppyModule,
name: str,
ty: str,
symbol: str | None = None,
constant: bool = True,
) -> RawExternDef:
"""Adds an extern symbol to a module."""
try:
type_ast = ast.parse(ty, mode="eval").body
except SyntaxError:
err = f"Not a valid Guppy type: `{ty}`"
raise GuppyError(err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if frame := inspect.currentframe(): # noqa: SIM102
if caller_frame := frame.f_back: # noqa: SIM102
if caller_module := inspect.getmodule(caller_frame):
info = inspect.getframeinfo(caller_frame)
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(type_ast, source, info.filename, 0)
# Modify the AST so that all sub-nodes span the entire line. We
# can't give a better location since we don't know the column
# offset of the `ty` argument
for node in [type_ast, *ast.walk(type_ast)]:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])

defn = RawExternDef(
DefId.fresh(module), name, None, symbol or name, constant, type_ast
)
module.register_def(defn)
return defn

def load(self, m: ModuleType | GuppyModule) -> None:
caller = self._get_python_caller()
if caller not in self._modules:
77 changes: 77 additions & 0 deletions guppylang/definition/extern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import ast
from dataclasses import dataclass, field

from hugr.serialization import ops

from guppylang.ast_util import AstNode
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV, VNode
from guppylang.tys.parsing import type_from_ast


@dataclass(frozen=True)
class RawExternDef(ParsableDef):
"""A raw extern symbol definition provided by the user."""

symbol: str
constant: bool
type_ast: ast.expr

description: str = field(default="extern", init=False)

def parse(self, globals: Globals) -> "ExternDef":
"""Parses and checks the user-provided signature of the function."""
return ExternDef(
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
self.symbol,
self.constant,
self.type_ast,
)


@dataclass(frozen=True)
class ExternDef(RawExternDef, ValueDef, CompilableDef):
"""An extern symbol definition."""

def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledExternDef":
"""Adds a Hugr constant node for the extern definition to the provided graph."""
custom_const = {
"symbol": self.symbol,
"typ": self.ty.to_hugr(),
"constant": self.constant,
}
value = ops.ExtensionValue(
extensions=["prelude"],
typ=self.ty.to_hugr(),
value=ops.CustomConst(c="ConstExternalSymbol", v=custom_const),
)
const_node = graph.add_constant(ops.Value(value), self.ty, parent)
return CompiledExternDef(
self.id,
self.name,
self.defined_at,
self.ty,
self.symbol,
self.constant,
self.type_ast,
const_node,
)


@dataclass(frozen=True)
class CompiledExternDef(ExternDef, CompiledValueDef):
"""An extern symbol definition that has been compiled to a Hugr constant."""

const_node: VNode

def load(
self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode
) -> OutPortV:
"""Loads the extern value into a local Hugr dataflow graph."""
return graph.add_load_constant(self.const_node.out_port(0)).out_port(0)
7 changes: 7 additions & 0 deletions tests/error/misc_errors/extern_bad_type.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: module = GuppyModule("test")
8:
9: guppy.extern(module, "x", ty="float[int]")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Type `float` is not parameterized
12 changes: 12 additions & 0 deletions tests/error/misc_errors/extern_bad_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.quantum import qubit

import guppylang.prelude.quantum as quantum


module = GuppyModule("test")

guppy.extern(module, "x", ty="float[int]")

module.compile()
11 changes: 10 additions & 1 deletion tests/error/test_misc_errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pathlib
import pytest

from guppylang import GuppyModule, guppy
from guppylang.error import GuppyError
from tests.error.util import run_error_test

path = pathlib.Path(__file__).parent.resolve() / "misc_errors"
@@ -19,5 +21,12 @@


@pytest.mark.parametrize("file", files)
def test_type_errors(file, capsys):
def test_misc_errors(file, capsys):
run_error_test(file, capsys)


def test_extern_bad_type_syntax():
module = GuppyModule("test")

with pytest.raises(GuppyError, match="Not a valid Guppy type: `foo bar`"):
guppy.extern(module, name="x", ty="foo bar")
50 changes: 50 additions & 0 deletions tests/integration/test_extern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from hugr.serialization import ops

from guppylang.decorator import guppy
from guppylang.module import GuppyModule


def test_extern_float(validate):
module = GuppyModule("module")

guppy.extern(module, "ext", ty="float")

@guppy(module)
def main() -> float:
return ext + ext # noqa: F821

hg = module.compile()
validate(hg)

[c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)]
assert isinstance(c.v.root, ops.ExtensionValue)
assert c.v.root.value.v["symbol"] == "ext"


def test_extern_alt_symbol(validate):
module = GuppyModule("module")

guppy.extern(module, "ext", ty="int", symbol="foo")

@guppy(module)
def main() -> int:
return ext # noqa: F821

hg = module.compile()
validate(hg)

[c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)]
assert isinstance(c.v.root, ops.ExtensionValue)
assert c.v.root.value.v["symbol"] == "foo"

def test_extern_tuple(validate):
module = GuppyModule("module")

guppy.extern(module, "ext", ty="tuple[int, float]")

@guppy(module)
def main() -> float:
x, y = ext # noqa: F821
return x + y

validate(module.compile())

0 comments on commit 977ccd8

Please sign in to comment.