Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Load pytket circuit as a function definition #672

Merged
merged 22 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
run: uv sync --extra pytket

- name: Rerun `py(...)` expression tests and pytket lowering with tket2 installed
run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py
run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py /tests/integration/test_pytket_circuits.py

test-coverage:
name: Check Python (3.13) with coverage
Expand Down
10 changes: 1 addition & 9 deletions guppylang/checker/errors/py_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,5 @@ class InstallInstruction(Help):
@dataclass(frozen=True)
class PytketSignatureMismatch(Error):
title: ClassVar[str] = "Signature mismatch"
span_label: ClassVar[str] = (
"Function signature {name} doesn't match provided pytket circuit"
)
span_label: ClassVar[str] = "Signature {name} doesn't match provided pytket circuit"
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
name: str


@dataclass(frozen=True)
class PytketNotCircuit(Error):
title: ClassVar[str] = "Input not circuit"
span_label: ClassVar[str] = "Provided input is not a pytket circuit"
10 changes: 10 additions & 0 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,16 @@ def pytket(
self, input_circuit: Any, module: GuppyModule | None = None
) -> PytketDecorator:
"""Adds a pytket circuit function definition with explicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.pytket"
try:
import pytket

if not isinstance(input_circuit, pytket.circuit.Circuit):
raise TypeError(err_msg) from None

except ImportError:
raise TypeError(err_msg) from None

mod = module or self.get_module()

def func(f: PyFunc) -> RawPytketDef:
Expand Down
68 changes: 26 additions & 42 deletions guppylang/definition/pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
from dataclasses import dataclass, field
from typing import Any, cast

import hugr.tys as ht
from hugr import Hugr, Node, Wire
import hugr.build.function as hf
from hugr import Hugr, OutPort, Wire, ops, tys, val
from hugr.build.dfg import DefinitionBuilder, OpVar
from hugr.ops import FuncDefn, Input
from hugr.tys import Bool

from guppylang.ast_util import AstNode, has_empty_body, with_loc
from guppylang.checker.core import Context, Globals, PyScope
from guppylang.checker.errors.py_errors import (
PytketNotCircuit,
PytketSignatureMismatch,
Tket2NotInstalled,
)
Expand All @@ -25,6 +22,7 @@
from guppylang.definition.declaration import BodyNotEmptyError
from guppylang.definition.function import (
PyFunc,
compile_call,
load_with_args,
parse_py_func,
)
Expand All @@ -41,7 +39,6 @@
InputFlags,
Type,
row_to_type,
type_to_row,
)


Expand Down Expand Up @@ -100,13 +97,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
circuit_signature.inputs == stub_signature.inputs
and circuit_signature.output == stub_signature.output
):
# TODO: Implement pretty-printing for signatures in order to add
# a note for expected vs. actual types.
raise GuppyError(PytketSignatureMismatch(func_ast, self.name))
except ImportError:
err = Tket2NotInstalled(func_ast)
err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None))
raise GuppyError(err) from None
else:
raise GuppyError(PytketNotCircuit(func_ast))
except ImportError:
pass

Expand Down Expand Up @@ -154,26 +151,29 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef"
mapping = module.hugr.insert_hugr(circ)
hugr_func = mapping[circ.root]

# We need to remove input bits from both signature and input node.
node_data = module.hugr.get(hugr_func)
# TODO: Error if hugr isn't FuncDefn?
if node_data and isinstance(node_data.op, FuncDefn):
func_defn = node_data.op
func_defn.f_name = self.name
num_bools = func_defn.inputs.count(Bool)
for _ in range(num_bools):
func_defn.inputs.remove(Bool)
func_type = self.ty.to_hugr_poly()
outer_func = module.define_function(
self.name, func_type.body.input, func_type.body.output
)

# Initialise every input bit in the circuit as false.
# TODO: Provide the option for the user to pass this input as well.
bool_wires: list[OutPort] = []
for child in module.hugr.children(hugr_func):
node_data = module.hugr.get(child)
if node_data and isinstance(node_data.op, Input):
if node_data and isinstance(node_data.op, ops.Input):
input_types = node_data.op.types
num_bools = input_types.count(Bool)
num_bools = input_types.count(tys.Bool)
for _ in range(num_bools):
input_types.remove(Bool)
bool_node = outer_func.load(val.FALSE)
bool_wires.append(*bool_node.outputs())
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

call_node = outer_func.call(
hugr_func, *(list(outer_func.inputs()) + bool_wires)
)
# Pytket circuit hugr has qubit and bool wires in the opposite order.
outer_func.set_outputs(*reversed(list(call_node.outputs())))
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

else:
raise GuppyError(PytketNotCircuit(self.defined_at))
except ImportError:
pass

Expand All @@ -184,7 +184,7 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef"
self.ty,
self.python_scope,
self.input_circuit,
hugr_func,
outer_func,
)

def check_call(
Expand Down Expand Up @@ -220,7 +220,7 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef):
func_df: The Hugr function definition.
"""

func_def: Node
func_def: hf.Function

def load_with_args(
self,
Expand All @@ -242,21 +242,5 @@ def compile_call(
node: AstNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(self.ty.output))
call = dfg.builder.call(
self.func_def, *args, instantiation=func_ty, type_args=type_args
)
# Negative index slicing doesn't work when num_returns is 0.
if num_returns == 0:
return CallReturnWires(
regular_returns=[],
inout_returns=list(call[0:]),
)
else:
# Circuit function returns are the other way round to Guppy returns.
return CallReturnWires(
regular_returns=list(call[-num_returns:]),
inout_returns=list(call[:-num_returns]),
)
# Use implementation from function definition.
return compile_call(args, type_args, dfg, self.ty, self.func_def)
9 changes: 9 additions & 0 deletions tests/error/py_errors/sig_mismatch.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Error: Signature mismatch (at <In [69]>:15:0)
|
13 |
14 | @guppy.pytket(circ, module)
15 | def guppy_circ(q: qubit) -> None: ...
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Function signature guppy_circ doesn't match provided pytket
| circuit

Guppy compilation failed due to 1 previous error
22 changes: 22 additions & 0 deletions tests/error/py_errors/sig_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pytket import Circuit

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit

circ = Circuit(2)
circ.X(0)
circ.Y(1)

module = GuppyModule("test")
module.load(qubit)

@guppy.pytket(circ, module)
def guppy_circ(q: qubit) -> None: ...

@guppy(module)
def foo(q: qubit) -> None:
guppy_circ(q)


module.compile()
15 changes: 9 additions & 6 deletions tests/integration/test_pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def foo(q: qubit) -> None:


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("remove this")
def test_multi_qubit_circuit(validate):
from pytket import Circuit

Expand Down Expand Up @@ -68,13 +69,14 @@ def test_measure(validate):
def guppy_circ(q: qubit) -> bool: ...

@guppy(module)
def foo(q: qubit) -> None:
result = guppy_circ(q)
def foo(q: qubit) -> bool:
return guppy_circ(q)

validate(module.compile())


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("remove this")
def test_measure_multiple(validate):
from pytket import Circuit

Expand All @@ -89,13 +91,14 @@ def test_measure_multiple(validate):
def guppy_circ(q1: qubit, q2: qubit) -> tuple[bool, bool]: ...

@guppy(module)
def foo(q1: qubit, q2: qubit) -> None:
result = guppy_circ(q1, q2)
def foo(q1: qubit, q2: qubit) -> tuple[bool, bool]:
return guppy_circ(q1, q2)

validate(module.compile())


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("remove this")
def test_measure_not_last(validate):
from pytket import Circuit

Expand All @@ -111,8 +114,8 @@ def test_measure_not_last(validate):
def guppy_circ(q: qubit) -> bool: ...

@guppy(module)
def foo(q: qubit) -> None:
result = guppy_circ(q)
def foo(q: qubit) -> bool:
return guppy_circ(q)

validate(module.compile())

Expand Down
Loading