Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Load pytket circuit as a function definition #672

Merged
merged 22 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
run: uv sync --extra pytket

- name: Rerun `py(...)` expression tests and pytket lowering with tket2 installed
run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py
run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py tests/integration/test_pytket_circuits.py

test-coverage:
name: Check Python (3.13) with coverage
Expand Down
1 change: 1 addition & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ ine = "ine"
inot = "inot"
inout = "inout"
inouts = "inouts"
anc = "anc"
7 changes: 7 additions & 0 deletions guppylang/checker/errors/py_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,10 @@ class Tket2NotInstalled(Error):
@dataclass(frozen=True)
class InstallInstruction(Help):
message: ClassVar[str] = "Install tket2: `pip install tket2`"


@dataclass(frozen=True)
class PytketSignatureMismatch(Error):
title: ClassVar[str] = "Signature mismatch"
span_label: ClassVar[str] = "Signature {name} doesn't match provided pytket circuit"
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
name: str
24 changes: 24 additions & 0 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RawFunctionDef,
)
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.pytket_circuits import RawPytketDef
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import MissingModuleError, pretty_errors
Expand Down Expand Up @@ -57,6 +58,7 @@
FuncDefDecorator = Decorator[PyFunc, RawFunctionDef]
FuncDeclDecorator = Decorator[PyFunc, RawFunctionDecl]
CustomFuncDecorator = Decorator[PyFunc, RawCustomFunctionDef]
PytketDecorator = Decorator[PyFunc, RawPytketDef]
ClassDecorator = Decorator[PyClass, PyClass]
OpaqueTypeDecorator = Decorator[PyClass, OpaqueTypeDef]
StructDecorator = Decorator[PyClass, RawStructDef]
Expand Down Expand Up @@ -468,6 +470,28 @@ def registered_modules(self) -> KeysView[ModuleIdentifier]:
"""Returns a list of all currently registered modules for local contexts."""
return self._modules.keys()

@pretty_errors
def pytket(
self, input_circuit: Any, module: GuppyModule | None = None
) -> PytketDecorator:
"""Adds a pytket circuit function definition with explicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.pytket"
try:
import pytket

if not isinstance(input_circuit, pytket.circuit.Circuit):
raise TypeError(err_msg) from None

except ImportError:
raise TypeError(err_msg) from None

mod = module or self.get_module()

def func(f: PyFunc) -> RawPytketDef:
return mod.register_pytket_func(f, input_circuit)

return func


class _GuppyDummy:
"""A dummy class with the same interface as `@guppy` that is used during sphinx
Expand Down
27 changes: 11 additions & 16 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import ClassVar

from hugr import Node, Wire
from hugr import tys as ht
from hugr.build import function as hf
from hugr.build.dfg import DefinitionBuilder, OpVar

Expand All @@ -13,14 +12,19 @@
from guppylang.checker.func_checker import check_signature
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.function import PyFunc, parse_py_func
from guppylang.definition.function import (
PyFunc,
compile_call,
load_with_args,
parse_py_func,
)
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.diagnostic import Error
from guppylang.error import GuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import Type, type_to_row
from guppylang.tys.ty import Type


@dataclass(frozen=True)
Expand Down Expand Up @@ -121,9 +125,8 @@ def load_with_args(
node: AstNode,
) -> Wire:
"""Loads the function as a value into a local Hugr dataflow graph."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
return dfg.builder.load_function(self.declaration, func_ty, type_args)
# Use implementation from function definition.
return load_with_args(type_args, dfg, self.ty, self.declaration)

def compile_call(
self,
Expand All @@ -134,13 +137,5 @@ def compile_call(
node: AstNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(self.ty.output))
call = dfg.builder.call(
self.declaration, *args, instantiation=func_ty, type_args=type_args
)
return CallReturnWires(
regular_returns=list(call[:num_returns]),
inout_returns=list(call[num_returns:]),
)
# Use implementation from function definition.
return compile_call(args, type_args, dfg, self.ty, self.declaration)
46 changes: 33 additions & 13 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import hugr.tys as ht
from hugr import Wire
from hugr.build.dfg import DefinitionBuilder, OpVar
from hugr.hugr.node_port import ToNode
from hugr.package import FuncDefnPointer

from guppylang.ast_util import AstNode, annotate_location, with_loc, with_type
Expand Down Expand Up @@ -199,9 +200,7 @@ def load_with_args(
node: AstNode,
) -> Wire:
"""Loads the function as a value into a local Hugr dataflow graph."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
return dfg.builder.load_function(self.func_def, func_ty, type_args)
return load_with_args(type_args, dfg, self.ty, self.func_def)

def compile_call(
self,
Expand All @@ -212,22 +211,43 @@ def compile_call(
node: AstNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(self.ty.output))
call = dfg.builder.call(
self.func_def, *args, instantiation=func_ty, type_args=type_args
)
return CallReturnWires(
regular_returns=list(call[:num_returns]),
inout_returns=list(call[num_returns:]),
)
return compile_call(args, type_args, dfg, self.ty, self.func_def)

def compile_inner(self, globals: CompiledGlobals) -> None:
"""Compiles the body of the function."""
compile_global_func_def(self, self.func_def, globals)


def load_with_args(
type_args: Inst,
dfg: DFContainer,
ty: FunctionType,
func: ToNode,
) -> Wire:
"""Loads the function as a value into a local Hugr dataflow graph."""
func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
return dfg.builder.load_function(func, func_ty, type_args)


def compile_call(
args: list[Wire],
type_args: Inst,
dfg: DFContainer,
ty: FunctionType,
func: ToNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(ty.output))
call = dfg.builder.call(func, *args, instantiation=func_ty, type_args=type_args)
return CallReturnWires(
regular_returns=list(call[:num_returns]),
inout_returns=list(call[num_returns:]),
)


def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | None]:
source_lines, line_offset = inspect.getsourcelines(f)
source = "".join(source_lines) # Lines already have trailing \n's
Expand Down
Loading
Loading