Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add extern symbols #236

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import inspect
from collections.abc import Callable
from dataclasses import dataclass, field
Expand All @@ -7,7 +8,7 @@

from hugr.serialization import ops, tys

from guppylang.ast_util import has_empty_body
from guppylang.ast_util import annotate_location, has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.custom import (
CustomCallChecker,
Expand All @@ -18,6 +19,7 @@
RawCustomFunctionDef,
)
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.extern import RawExternDef
from guppylang.definition.function import RawFunctionDef, parse_py_func
from guppylang.definition.parameter import TypeVarDef
from guppylang.definition.struct import RawStructDef
Expand Down Expand Up @@ -226,6 +228,43 @@ def dec(f: Callable[..., Any]) -> RawFunctionDecl:

return dec

def extern(
self,
module: GuppyModule,
name: str,
ty: str,
symbol: str | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit dubious, I think name should and symbol should take it's place?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation behind leaving them distinct was that the LLVM symbol name might not follow Python's naming conventions, so you'd want to pick a different name to refer to it in your program.

But I'm fine with making symbol required and name optional.

Copy link
Contributor

@doug-q doug-q Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be fine with symbol optional, in fact I weakly prefer it(now that I understand better). My reading of the original code was that name did not propagate into the symbol of the ConstExternSymbol. I think you should add tests for how the symbol of ConstExternSymbol is populated with and without the optional param (name or symbol, whichever you decide) here.

constant: bool = True,
) -> RawExternDef:
"""Adds an extern symbol to a module."""
try:
type_ast = ast.parse(ty, mode="eval").body
except SyntaxError:
err = f"Not a valid Guppy type: `{ty}`"
raise GuppyError(err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if frame := inspect.currentframe(): # noqa: SIM102
if caller_frame := frame.f_back: # noqa: SIM102
if caller_module := inspect.getmodule(caller_frame):
info = inspect.getframeinfo(caller_frame)
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(type_ast, source, info.filename, 0)
# Modify the AST so that all sub-nodes span the entire line. We
# can't give a better location since we don't know the column
# offset of the `ty` argument
for node in [type_ast, *ast.walk(type_ast)]:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])

defn = RawExternDef(
DefId.fresh(module), name, None, symbol or name, constant, type_ast
)
module.register_def(defn)
return defn

def load(self, m: ModuleType | GuppyModule) -> None:
caller = self._get_python_caller()
if caller not in self._modules:
Expand Down
77 changes: 77 additions & 0 deletions guppylang/definition/extern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import ast
from dataclasses import dataclass, field

from hugr.serialization import ops

from guppylang.ast_util import AstNode
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV, VNode
from guppylang.tys.parsing import type_from_ast


@dataclass(frozen=True)
class RawExternDef(ParsableDef):
"""A raw extern symbol definition provided by the user."""

symbol: str
constant: bool
type_ast: ast.expr

description: str = field(default="extern", init=False)

def parse(self, globals: Globals) -> "ExternDef":
"""Parses and checks the user-provided signature of the function."""
return ExternDef(
self.id,
self.name,
self.defined_at,
type_from_ast(self.type_ast, globals, None),
self.symbol,
self.constant,
self.type_ast,
)


@dataclass(frozen=True)
class ExternDef(RawExternDef, ValueDef, CompilableDef):
"""An extern symbol definition."""

def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledExternDef":
"""Adds a Hugr constant node for the extern definition to the provided graph."""
custom_const = {
"symbol": self.symbol,
"typ": self.ty.to_hugr(),
"constant": self.constant,
}
value = ops.ExtensionValue(
extensions=["prelude"],
typ=self.ty.to_hugr(),
value=ops.CustomConst(c="ConstExternalSymbol", v=custom_const),
)
const_node = graph.add_constant(ops.Value(value), self.ty, parent)
return CompiledExternDef(
self.id,
self.name,
self.defined_at,
self.ty,
self.symbol,
self.constant,
self.type_ast,
const_node,
)


@dataclass(frozen=True)
class CompiledExternDef(ExternDef, CompiledValueDef):
"""An extern symbol definition that has been compiled to a Hugr constant."""

const_node: VNode

def load(
self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode
) -> OutPortV:
"""Loads the extern value into a local Hugr dataflow graph."""
return graph.add_load_constant(self.const_node.out_port(0)).out_port(0)
7 changes: 7 additions & 0 deletions tests/error/misc_errors/extern_bad_type.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: module = GuppyModule("test")
8:
9: guppy.extern(module, "x", ty="float[int]")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Type `float` is not parameterized
12 changes: 12 additions & 0 deletions tests/error/misc_errors/extern_bad_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.quantum import qubit

import guppylang.prelude.quantum as quantum


module = GuppyModule("test")

guppy.extern(module, "x", ty="float[int]")

module.compile()
11 changes: 10 additions & 1 deletion tests/error/test_misc_errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pathlib
import pytest

from guppylang import GuppyModule, guppy
from guppylang.error import GuppyError
from tests.error.util import run_error_test

path = pathlib.Path(__file__).parent.resolve() / "misc_errors"
Expand All @@ -19,5 +21,12 @@


@pytest.mark.parametrize("file", files)
def test_type_errors(file, capsys):
def test_misc_errors(file, capsys):
run_error_test(file, capsys)


def test_extern_bad_type_syntax():
module = GuppyModule("test")

with pytest.raises(GuppyError, match="Not a valid Guppy type: `foo bar`"):
guppy.extern(module, name="x", ty="foo bar")
50 changes: 50 additions & 0 deletions tests/integration/test_extern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from hugr.serialization import ops

from guppylang.decorator import guppy
from guppylang.module import GuppyModule


def test_extern_float(validate):
module = GuppyModule("module")

guppy.extern(module, "ext", ty="float")

@guppy(module)
def main() -> float:
return ext + ext # noqa: F821

hg = module.compile()
validate(hg)

[c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)]
assert isinstance(c.v.root, ops.ExtensionValue)
assert c.v.root.value.v["symbol"] == "ext"


def test_extern_alt_symbol(validate):
module = GuppyModule("module")

guppy.extern(module, "ext", ty="int", symbol="foo")

@guppy(module)
def main() -> int:
return ext # noqa: F821

hg = module.compile()
validate(hg)

[c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)]
assert isinstance(c.v.root, ops.ExtensionValue)
assert c.v.root.value.v["symbol"] == "foo"

def test_extern_tuple(validate):
module = GuppyModule("module")

guppy.extern(module, "ext", ty="tuple[int, float]")

@guppy(module)
def main() -> float:
x, y = ext # noqa: F821
return x + y

validate(module.compile())
Loading