From a9e40908bf18f5df486d55d8d75ce01f99345f12 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Mon, 25 Nov 2024 17:33:45 +0000 Subject: [PATCH 01/21] Add empty classes for pytket circuit definitions --- guppylang/definition/pytket_circ.py | 78 +++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 guppylang/definition/pytket_circ.py diff --git a/guppylang/definition/pytket_circ.py b/guppylang/definition/pytket_circ.py new file mode 100644 index 00000000..938fffae --- /dev/null +++ b/guppylang/definition/pytket_circ.py @@ -0,0 +1,78 @@ +import ast +import inspect +import textwrap +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import hugr.build.function as hf +import hugr.tys as ht +from hugr.build.dfg import DefinitionBuilder, OpVar + + +from guppylang.checker.core import Globals, PyScope +from guppylang.definition.common import ( + CheckableDef, + CompilableDef, + ParsableDef, +) +from guppylang.definition.value import CompiledCallableDef +from guppylang.error import GuppyError +from guppylang.span import SourceMap + +PyFunc = Callable[..., Any] + +@dataclass(frozen=True) +class RawPytketDef(ParsableDef): + """A raw function stub definition describing the signature of a circuit. + + Args: + id: The unique definition identifier. + name: The name of the function stub. + defined_at: The AST node where the stub was defined. + python_func: The Python function stub. + python_scope: The Python scope where the function stub was defined. + """ + + python_func: PyFunc + python_scope: PyScope + + description: str = field(default="pytket circuit", init=False) + + def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": + """Parses and checks the user-provided signature matches the user-provided circuit.""" + pass + + +@dataclass(frozen=True) +class ParsedPytketDef(CheckableDef, CompilableDef): + """A circuit definition with parsed and checked signature. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + """ + + description: str = field(default="pytket circuit", init=False) + + def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": + """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" + pass + + +class CompiledPytketDef(CompiledCallableDef): + """A function definition with a corresponding Hugr node. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + func_df: The Hugr function definition. + """ + + func_def: hf.Function \ No newline at end of file From 76ca28c1ea98a14f2d7f05b7472d54328e86c282 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Mon, 25 Nov 2024 17:34:02 +0000 Subject: [PATCH 02/21] Start implementing pytket circuit decorator --- guppylang/decorator.py | 17 +++++++++++++++++ guppylang/module.py | 9 +++++++++ 2 files changed, 26 insertions(+) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 2a4244ad..2eacec75 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -30,6 +30,7 @@ RawFunctionDef, ) from guppylang.definition.parameter import ConstVarDef, TypeVarDef +from guppylang.definition.pytket_circ import RawPytketDef from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import MissingModuleError, pretty_errors @@ -57,6 +58,7 @@ FuncDefDecorator = Decorator[PyFunc, RawFunctionDef] FuncDeclDecorator = Decorator[PyFunc, RawFunctionDecl] CustomFuncDecorator = Decorator[PyFunc, RawCustomFunctionDef] +PytketDecorator = Decorator[PyFunc, RawPytketDef] ClassDecorator = Decorator[PyClass, PyClass] OpaqueTypeDecorator = Decorator[PyClass, OpaqueTypeDef] StructDecorator = Decorator[PyClass, RawStructDef] @@ -465,6 +467,21 @@ def compile_function(self, f_def: RawFunctionDef) -> FuncDefnPointer: def registered_modules(self) -> KeysView[ModuleIdentifier]: """Returns a list of all currently registered modules for local contexts.""" return self._modules.keys() + + # TODO: Circuit type and propagation + @pretty_errors + def pytket_circ( + self, + circuit: Any, + name: str = "", + module: GuppyModule | None = None + + ) -> PytketDecorator: + """Adds a pytket circuit function definition.""" + module = module or self.get_module() + def dec(f: PyFunc) -> RawPytketDef: + return module.register_pytket_func(f, circuit) + return dec class _GuppyDummy: diff --git a/guppylang/module.py b/guppylang/module.py index 792270c6..930c5104 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -27,6 +27,7 @@ from guppylang.definition.function import RawFunctionDef from guppylang.definition.module import ModuleDef from guppylang.definition.parameter import ParamDef +from guppylang.definition.pytket_circ import RawPytketDef from guppylang.definition.struct import CheckedStructDef from guppylang.definition.ty import TypeDef from guppylang.error import pretty_errors @@ -236,6 +237,14 @@ def register_func_decl( decl = RawFunctionDecl(DefId.fresh(self), f.__name__, None, f, get_py_scope(f)) self.register_def(decl, instance) return decl + + def register_pytket_func( + self, f: PyFunc, instance: TypeDef | None = None + ) -> RawPytketDef: + """Registers a pytket circuit function as belonging to this Guppy module.""" + decl = RawPytketDef(DefId.fresh(self), f.__name__, None, f, get_py_scope(f)) + self.register_def(decl, instance) + return decl def _register_buffered_instance_funcs(self, instance: TypeDef) -> None: assert self._instance_func_buffer is not None From 76fb0a95b57c11f79c29716f146f57bae2369dc1 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Tue, 26 Nov 2024 09:19:24 +0000 Subject: [PATCH 03/21] Run format and check --- guppylang/decorator.py | 10 ++++------ guppylang/definition/pytket_circ.py | 15 +++++---------- guppylang/module.py | 2 +- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 2eacec75..ab91fb2d 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -467,20 +467,18 @@ def compile_function(self, f_def: RawFunctionDef) -> FuncDefnPointer: def registered_modules(self) -> KeysView[ModuleIdentifier]: """Returns a list of all currently registered modules for local contexts.""" return self._modules.keys() - + # TODO: Circuit type and propagation @pretty_errors def pytket_circ( - self, - circuit: Any, - name: str = "", - module: GuppyModule | None = None - + self, circuit: Any, name: str = "", module: GuppyModule | None = None ) -> PytketDecorator: """Adds a pytket circuit function definition.""" module = module or self.get_module() + def dec(f: PyFunc) -> RawPytketDef: return module.register_pytket_func(f, circuit) + return dec diff --git a/guppylang/definition/pytket_circ.py b/guppylang/definition/pytket_circ.py index 938fffae..69f4b1a1 100644 --- a/guppylang/definition/pytket_circ.py +++ b/guppylang/definition/pytket_circ.py @@ -1,15 +1,10 @@ -import ast -import inspect -import textwrap from collections.abc import Callable from dataclasses import dataclass, field from typing import Any import hugr.build.function as hf -import hugr.tys as ht from hugr.build.dfg import DefinitionBuilder, OpVar - from guppylang.checker.core import Globals, PyScope from guppylang.definition.common import ( CheckableDef, @@ -17,11 +12,11 @@ ParsableDef, ) from guppylang.definition.value import CompiledCallableDef -from guppylang.error import GuppyError from guppylang.span import SourceMap PyFunc = Callable[..., Any] + @dataclass(frozen=True) class RawPytketDef(ParsableDef): """A raw function stub definition describing the signature of a circuit. @@ -40,8 +35,9 @@ class RawPytketDef(ParsableDef): description: str = field(default="pytket circuit", init=False) def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": - """Parses and checks the user-provided signature matches the user-provided circuit.""" - pass + """Parses and checks the user-provided signature matches the + user-provided circuit. + """ @dataclass(frozen=True) @@ -60,7 +56,6 @@ class ParsedPytketDef(CheckableDef, CompilableDef): def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" - pass class CompiledPytketDef(CompiledCallableDef): @@ -75,4 +70,4 @@ class CompiledPytketDef(CompiledCallableDef): func_df: The Hugr function definition. """ - func_def: hf.Function \ No newline at end of file + func_def: hf.Function diff --git a/guppylang/module.py b/guppylang/module.py index 930c5104..413520ef 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -237,7 +237,7 @@ def register_func_decl( decl = RawFunctionDecl(DefId.fresh(self), f.__name__, None, f, get_py_scope(f)) self.register_def(decl, instance) return decl - + def register_pytket_func( self, f: PyFunc, instance: TypeDef | None = None ) -> RawPytketDef: From dfc7da6df866937f4c4c8e2c858bd6d5586772e3 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Tue, 26 Nov 2024 14:25:10 +0000 Subject: [PATCH 04/21] Implement parse for circuit signature stub --- guppylang/checker/errors/py_errors.py | 9 ++ guppylang/decorator.py | 9 +- guppylang/definition/pytket_circ.py | 73 ------------ guppylang/definition/pytket_circuits.py | 150 ++++++++++++++++++++++++ guppylang/module.py | 8 +- 5 files changed, 168 insertions(+), 81 deletions(-) delete mode 100644 guppylang/definition/pytket_circ.py create mode 100644 guppylang/definition/pytket_circuits.py diff --git a/guppylang/checker/errors/py_errors.py b/guppylang/checker/errors/py_errors.py index e3377888..fb60d75b 100644 --- a/guppylang/checker/errors/py_errors.py +++ b/guppylang/checker/errors/py_errors.py @@ -53,3 +53,12 @@ class Tket2NotInstalled(Error): @dataclass(frozen=True) class InstallInstruction(Help): message: ClassVar[str] = "Install tket2: `pip install tket2`" + + +@dataclass(frozen=True) +class PytketSignatureMismatch(Error): + title: ClassVar[str] = "Signature mismatch" + span_label: ClassVar[str] = ( + "Function signature {name} doesn't match provided pytket circuit" + ) + name: str diff --git a/guppylang/decorator.py b/guppylang/decorator.py index ab91fb2d..d3787a44 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -30,7 +30,7 @@ RawFunctionDef, ) from guppylang.definition.parameter import ConstVarDef, TypeVarDef -from guppylang.definition.pytket_circ import RawPytketDef +from guppylang.definition.pytket_circuits import RawPytketDef from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import MissingModuleError, pretty_errors @@ -468,16 +468,15 @@ def registered_modules(self) -> KeysView[ModuleIdentifier]: """Returns a list of all currently registered modules for local contexts.""" return self._modules.keys() - # TODO: Circuit type and propagation @pretty_errors - def pytket_circ( - self, circuit: Any, name: str = "", module: GuppyModule | None = None + def pytket( + self, input_circuit: Any, module: GuppyModule | None = None ) -> PytketDecorator: """Adds a pytket circuit function definition.""" module = module or self.get_module() def dec(f: PyFunc) -> RawPytketDef: - return module.register_pytket_func(f, circuit) + return module.register_pytket_func(f, input_circuit) return dec diff --git a/guppylang/definition/pytket_circ.py b/guppylang/definition/pytket_circ.py deleted file mode 100644 index 69f4b1a1..00000000 --- a/guppylang/definition/pytket_circ.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - -import hugr.build.function as hf -from hugr.build.dfg import DefinitionBuilder, OpVar - -from guppylang.checker.core import Globals, PyScope -from guppylang.definition.common import ( - CheckableDef, - CompilableDef, - ParsableDef, -) -from guppylang.definition.value import CompiledCallableDef -from guppylang.span import SourceMap - -PyFunc = Callable[..., Any] - - -@dataclass(frozen=True) -class RawPytketDef(ParsableDef): - """A raw function stub definition describing the signature of a circuit. - - Args: - id: The unique definition identifier. - name: The name of the function stub. - defined_at: The AST node where the stub was defined. - python_func: The Python function stub. - python_scope: The Python scope where the function stub was defined. - """ - - python_func: PyFunc - python_scope: PyScope - - description: str = field(default="pytket circuit", init=False) - - def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": - """Parses and checks the user-provided signature matches the - user-provided circuit. - """ - - -@dataclass(frozen=True) -class ParsedPytketDef(CheckableDef, CompilableDef): - """A circuit definition with parsed and checked signature. - - Args: - id: The unique definition identifier. - name: The name of the function. - defined_at: The AST node where the function was defined. - ty: The type of the function. - python_scope: The Python scope where the function was defined. - """ - - description: str = field(default="pytket circuit", init=False) - - def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": - """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" - - -class CompiledPytketDef(CompiledCallableDef): - """A function definition with a corresponding Hugr node. - - Args: - id: The unique definition identifier. - name: The name of the function. - defined_at: The AST node where the function was defined. - ty: The type of the function. - python_scope: The Python scope where the function was defined. - func_df: The Hugr function definition. - """ - - func_def: hf.Function diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py new file mode 100644 index 00000000..fc42be9c --- /dev/null +++ b/guppylang/definition/pytket_circuits.py @@ -0,0 +1,150 @@ +import ast +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, cast + +import hugr.build.function as hf +from hugr.build.dfg import DefinitionBuilder, OpVar + +from guppylang.ast_util import has_empty_body +from guppylang.checker.core import Globals, PyScope +from guppylang.checker.func_checker import check_signature +from guppylang.definition.common import ( + CheckableDef, + CompilableDef, + ParsableDef, +) +from guppylang.definition.declaration import BodyNotEmptyError +from guppylang.definition.function import PyFunc, parse_py_func +from guppylang.definition.ty import TypeDef +from guppylang.definition.value import CompiledCallableDef +from guppylang.span import SourceMap + +from guppylang.checker.errors.py_errors import ( + PytketSignatureMismatch, + Tket2NotInstalled, +) +from guppylang.error import ( + GuppyError, +) +from guppylang.tys.builtin import ( + bool_type, +) +from guppylang.tys.ty import ( + FuncInput, + FunctionType, + InputFlags, + row_to_type, +) + + +@dataclass(frozen=True) +class RawPytketDef(ParsableDef): + """A raw function stub definition describing the signature of a circuit. + + Args: + id: The unique definition identifier. + name: The name of the function stub. + defined_at: The AST node where the stub was defined. + python_func: The Python function stub. + python_scope: The Python scope where the function stub was defined. + input_circuit: The user-provided pytket circuit. + """ + + python_func: PyFunc + python_scope: PyScope + input_circuit: Any + + description: str = field(default="pytket circuit", init=False) + + def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": + """Parses and checks the user-provided signature matches the user-provided + circuit. + """ + # Retrieve stub signature. + func_ast, _ = parse_py_func(self.python_func, sources) + if not has_empty_body(func_ast): + # Function stub should have empty body. + raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name)) + stub_signature = check_signature( + func_ast, globals.with_python_scope(self.python_scope) + ) + + # Retrieve circuit signature and compare. + try: + import pytket + + if isinstance(self.input_circuit, pytket.circuit.Circuit): + try: + import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401 + + qubit = cast(TypeDef, globals["qubit"]).check_instantiate( + [], globals + ) + circuit_signature = FunctionType( + [FuncInput(qubit, InputFlags.Inout)] + * self.input_circuit.n_qubits, + row_to_type([bool_type()] * self.input_circuit.n_bits), + ) + # TODO: Allow arrays in stub signature. + if not ( + circuit_signature.inputs == stub_signature.inputs + and circuit_signature.output == stub_signature.output + ): + raise GuppyError( + PytketSignatureMismatch(self.defined_at, self.name) + ) + + except ImportError: + err = Tket2NotInstalled(self.defined_at) + err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None)) + raise GuppyError(err) from None + except ImportError: + pass + return ParsedPytketDef( + self.id, + self.name, + func_ast, + stub_signature, + self.python_scope, + self.input_circuit, + ) + + +@dataclass(frozen=True) +class ParsedPytketDef(CompilableDef): + """A circuit definition with parsed and checked signature. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + input_circuit: The user-provided pytket circuit. + """ + + defined_at: ast.FunctionDef + ty: FunctionType + python_scope: PyScope + input_circuit: Any + + description: str = field(default="pytket circuit", init=False) + + def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": + """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" + + +class CompiledPytketDef(CompiledCallableDef): + """A function definition with a corresponding Hugr node. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + func_df: The Hugr function definition. + """ + + func_def: hf.Function diff --git a/guppylang/module.py b/guppylang/module.py index 413520ef..c9fc8a48 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -27,7 +27,7 @@ from guppylang.definition.function import RawFunctionDef from guppylang.definition.module import ModuleDef from guppylang.definition.parameter import ParamDef -from guppylang.definition.pytket_circ import RawPytketDef +from guppylang.definition.pytket_circuits import RawPytketDef from guppylang.definition.struct import CheckedStructDef from guppylang.definition.ty import TypeDef from guppylang.error import pretty_errors @@ -239,10 +239,12 @@ def register_func_decl( return decl def register_pytket_func( - self, f: PyFunc, instance: TypeDef | None = None + self, f: PyFunc, input_value: Any, instance: TypeDef | None = None ) -> RawPytketDef: """Registers a pytket circuit function as belonging to this Guppy module.""" - decl = RawPytketDef(DefId.fresh(self), f.__name__, None, f, get_py_scope(f)) + decl = RawPytketDef( + DefId.fresh(self), f.__name__, None, f, get_py_scope(f), input_value + ) self.register_def(decl, instance) return decl From 188aa0cc999c98bfcbed8aab6ef62d5e1611bb5e Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Tue, 26 Nov 2024 17:24:28 +0000 Subject: [PATCH 05/21] Start implementing conversion to hugr function definition --- guppylang/checker/errors/py_errors.py | 7 +++ guppylang/definition/pytket_circuits.py | 75 ++++++++++++++++++++----- 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/guppylang/checker/errors/py_errors.py b/guppylang/checker/errors/py_errors.py index fb60d75b..04eb2af5 100644 --- a/guppylang/checker/errors/py_errors.py +++ b/guppylang/checker/errors/py_errors.py @@ -62,3 +62,10 @@ class PytketSignatureMismatch(Error): "Function signature {name} doesn't match provided pytket circuit" ) name: str + +@dataclass(frozen=True) +class PytketNotCircuit(Error): + title: ClassVar[str] = "Input not circuit" + span_label: ClassVar[str] = ( + "Provided input is not a pytket circuit" + ) diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index fc42be9c..f2383009 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -3,19 +3,24 @@ from dataclasses import dataclass, field from typing import Any, cast +from hugr import Hugr, Wire, ops import hugr.build.function as hf +import hugr.tys as ht from hugr.build.dfg import DefinitionBuilder, OpVar +from guppylang.ast_util import AstNode, with_loc +from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.ast_util import has_empty_body -from guppylang.checker.core import Globals, PyScope +from guppylang.checker.core import Context, Globals, PyScope from guppylang.checker.func_checker import check_signature +from guppylang.compiler.core import CompiledGlobals, DFContainer from guppylang.definition.common import ( - CheckableDef, CompilableDef, ParsableDef, ) +from guppylang.definition.value import CallReturnWires, CallableDef from guppylang.definition.declaration import BodyNotEmptyError -from guppylang.definition.function import PyFunc, parse_py_func +from guppylang.definition.function import CheckedFunctionDef, CompiledFunctionDef, PyFunc, parse_py_func from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef from guppylang.span import SourceMap @@ -23,18 +28,19 @@ from guppylang.checker.errors.py_errors import ( PytketSignatureMismatch, Tket2NotInstalled, + PytketNotCircuit ) -from guppylang.error import ( - GuppyError, -) -from guppylang.tys.builtin import ( - bool_type, -) +from guppylang.error import GuppyError +from guppylang.nodes import GlobalCall +from guppylang.tys.builtin import bool_type +from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( FuncInput, FunctionType, + Type, InputFlags, row_to_type, + type_to_row, ) @@ -92,13 +98,14 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": and circuit_signature.output == stub_signature.output ): raise GuppyError( - PytketSignatureMismatch(self.defined_at, self.name) + PytketSignatureMismatch(func_ast, self.name) ) - except ImportError: - err = Tket2NotInstalled(self.defined_at) + err = Tket2NotInstalled(func_ast) err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None)) raise GuppyError(err) from None + else: + raise GuppyError(PytketNotCircuit(func_ast)) except ImportError: pass return ParsedPytketDef( @@ -112,7 +119,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": @dataclass(frozen=True) -class ParsedPytketDef(CompilableDef): +class ParsedPytketDef(CallableDef, CompilableDef): """A circuit definition with parsed and checked signature. Args: @@ -134,8 +141,45 @@ class ParsedPytketDef(CompilableDef): def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" + try: + import pytket -class CompiledPytketDef(CompiledCallableDef): + if isinstance(self.input_circuit, pytket.circuit.Circuit): + from tket2.circuit import ( # type: ignore[import-untyped, import-not-found, unused-ignore] + Tk2Circuit, + ) + circ = Hugr.load_json(Tk2Circuit(self.input_circuit).to_hugr_json()) # type: ignore[attr-defined, unused-ignore] + print(circ) + else: + raise GuppyError(PytketNotCircuit(self.defined_at)) + except ImportError: + pass + + hugr_func = None + return CompiledPytketDef(self.id, self.name, self.defined_at, self.ty, self.python_scope, None, None, hugr_func) + + + def check_call( + self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context + ) -> tuple[ast.expr, Subst]: + """Checks the return type of a function call against a given type.""" + # Use default implementation from the expression checker + args, subst, inst = check_call(self.ty, args, ty, node, ctx) + node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst)) + return node, subst + + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[ast.expr, Type]: + """Synthesizes the return type of a function call.""" + # Use default implementation from the expression checker + args, ty, inst = synthesize_call(self.ty, args, node, ctx) + node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst)) + return node, ty + + +@dataclass(frozen=True) +class CompiledPytketDef(CompiledFunctionDef, CompiledCallableDef): """A function definition with a corresponding Hugr node. Args: @@ -148,3 +192,6 @@ class CompiledPytketDef(CompiledCallableDef): """ func_def: hf.Function + + def compile_inner(self, globals: CompiledGlobals) -> None: + pass From ce222b0ae0f1b9d882a011cf3c7ea89ca289c9eb Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Tue, 26 Nov 2024 17:24:51 +0000 Subject: [PATCH 06/21] Add basic test --- tests/integration/test_pytket_circuits.py | 38 +++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/integration/test_pytket_circuits.py diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py new file mode 100644 index 00000000..3653665f --- /dev/null +++ b/tests/integration/test_pytket_circuits.py @@ -0,0 +1,38 @@ +"""Tests for loading pytket circuits as functions.""" + +from importlib.util import find_spec + +import pytest + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std import quantum +from guppylang.std.quantum import qubit + +tket2_installed = find_spec("tket2") is not None + + +# @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +def test_single_qubit_circuit(validate): + from pytket import Circuit + + circ = Circuit(1) + circ.H(0) + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q1: qubit) -> None: + """Function stub for circuit""" + + @guppy(module) + def bar(q: qubit) -> None: + pass + + @guppy(module) + def foo(q: qubit) -> None: + bar(q) + guppy_circ(q) + + validate(module.compile()) \ No newline at end of file From 4efe329aa6ec7cddb956b7409363145ac7601e8f Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Wed, 27 Nov 2024 10:58:40 +0000 Subject: [PATCH 07/21] Extract compiled function methods --- guppylang/checker/errors/py_errors.py | 5 +- guppylang/definition/function.py | 51 ++++++++++----- guppylang/definition/pytket_circuits.py | 86 +++++++++++++++++-------- 3 files changed, 96 insertions(+), 46 deletions(-) diff --git a/guppylang/checker/errors/py_errors.py b/guppylang/checker/errors/py_errors.py index 04eb2af5..3158a3a1 100644 --- a/guppylang/checker/errors/py_errors.py +++ b/guppylang/checker/errors/py_errors.py @@ -63,9 +63,8 @@ class PytketSignatureMismatch(Error): ) name: str + @dataclass(frozen=True) class PytketNotCircuit(Error): title: ClassVar[str] = "Input not circuit" - span_label: ClassVar[str] = ( - "Provided input is not a pytket circuit" - ) + span_label: ClassVar[str] = "Provided input is not a pytket circuit" diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index 94f85f19..b77fc601 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -199,9 +199,7 @@ def load_with_args( node: AstNode, ) -> Wire: """Loads the function as a value into a local Hugr dataflow graph.""" - func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() - type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] - return dfg.builder.load_function(self.func_def, func_ty, type_args) + return load_with_args(type_args, dfg, self.ty, self.func_def) def compile_call( self, @@ -212,24 +210,47 @@ def compile_call( node: AstNode, ) -> CallReturnWires: """Compiles a call to the function.""" - func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() - type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] - num_returns = len(type_to_row(self.ty.output)) - call = dfg.builder.call( - self.func_def, *args, instantiation=func_ty, type_args=type_args - ) - return CallReturnWires( - # TODO: Replace below with `list(call[:num_returns])` once - # https://github.com/CQCL/hugr/issues/1454 is fixed. - regular_returns=[call[i] for i in range(num_returns)], - inout_returns=list(call[num_returns:]), - ) + return compile_call(args, type_args, dfg, self.ty, self.func_def) def compile_inner(self, globals: CompiledGlobals) -> None: """Compiles the body of the function.""" compile_global_func_def(self, self.func_def, globals) +def load_with_args( + type_args: Inst, + dfg: DFContainer, + ty: FunctionType, + # TODO: Maybe change to ToNode so this can be used by declarations. + func: hf.Function, +) -> Wire: + """Loads the function as a value into a local Hugr dataflow graph.""" + func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + return dfg.builder.load_function(func, func_ty, type_args) + + +def compile_call( + args: list[Wire], + type_args: Inst, + dfg: DFContainer, + ty: FunctionType, + # TODO: Maybe change to ToNode so this can be used by declarations. + func: hf.Function, +) -> CallReturnWires: + """Compiles a call to the function.""" + func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + num_returns = len(type_to_row(ty.output)) + call = dfg.builder.call(func, *args, instantiation=func_ty, type_args=type_args) + return CallReturnWires( + # TODO: Replace below with `list(call[:num_returns])` once + # https://github.com/CQCL/hugr/issues/1454 is fixed. + regular_returns=[call[i] for i in range(num_returns)], + inout_returns=list(call[num_returns:]), + ) + + def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | None]: source_lines, line_offset = inspect.getsourcelines(f) source = "".join(source_lines) # Lines already have trailing \n's diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index f2383009..2d2f0468 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -1,46 +1,45 @@ import ast -from collections.abc import Callable from dataclasses import dataclass, field from typing import Any, cast -from hugr import Hugr, Wire, ops import hugr.build.function as hf -import hugr.tys as ht +from hugr import Hugr, Wire from hugr.build.dfg import DefinitionBuilder, OpVar -from guppylang.ast_util import AstNode, with_loc -from guppylang.checker.expr_checker import check_call, synthesize_call -from guppylang.ast_util import has_empty_body +from guppylang.ast_util import AstNode, has_empty_body, with_loc from guppylang.checker.core import Context, Globals, PyScope +from guppylang.checker.errors.py_errors import ( + PytketNotCircuit, + PytketSignatureMismatch, + Tket2NotInstalled, +) +from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import check_signature from guppylang.compiler.core import CompiledGlobals, DFContainer from guppylang.definition.common import ( CompilableDef, ParsableDef, ) -from guppylang.definition.value import CallReturnWires, CallableDef from guppylang.definition.declaration import BodyNotEmptyError -from guppylang.definition.function import CheckedFunctionDef, CompiledFunctionDef, PyFunc, parse_py_func -from guppylang.definition.ty import TypeDef -from guppylang.definition.value import CompiledCallableDef -from guppylang.span import SourceMap - -from guppylang.checker.errors.py_errors import ( - PytketSignatureMismatch, - Tket2NotInstalled, - PytketNotCircuit +from guppylang.definition.function import ( + PyFunc, + compile_call, + load_with_args, + parse_py_func, ) +from guppylang.definition.ty import TypeDef +from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef from guppylang.error import GuppyError from guppylang.nodes import GlobalCall +from guppylang.span import SourceMap from guppylang.tys.builtin import bool_type from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( FuncInput, FunctionType, - Type, InputFlags, + Type, row_to_type, - type_to_row, ) @@ -97,9 +96,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": circuit_signature.inputs == stub_signature.inputs and circuit_signature.output == stub_signature.output ): - raise GuppyError( - PytketSignatureMismatch(func_ast, self.name) - ) + raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) except ImportError: err = Tket2NotInstalled(func_ast) err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None)) @@ -141,6 +138,7 @@ class ParsedPytketDef(CallableDef, CompilableDef): def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" + hugr_func = None try: import pytket @@ -148,17 +146,25 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" from tket2.circuit import ( # type: ignore[import-untyped, import-not-found, unused-ignore] Tk2Circuit, ) - circ = Hugr.load_json(Tk2Circuit(self.input_circuit).to_hugr_json()) # type: ignore[attr-defined, unused-ignore] - print(circ) + + hugr_func = Hugr.load_json( + Tk2Circuit(self.input_circuit).to_hugr_json() + ) # type: ignore[attr-defined, unused-ignore] else: raise GuppyError(PytketNotCircuit(self.defined_at)) except ImportError: pass - hugr_func = None - return CompiledPytketDef(self.id, self.name, self.defined_at, self.ty, self.python_scope, None, None, hugr_func) - - + return CompiledPytketDef( + self.id, + self.name, + self.defined_at, + self.ty, + self.python_scope, + self.input_circuit, + hugr_func, + ) + def check_call( self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context ) -> tuple[ast.expr, Subst]: @@ -179,7 +185,7 @@ def synthesize_call( @dataclass(frozen=True) -class CompiledPytketDef(CompiledFunctionDef, CompiledCallableDef): +class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef): """A function definition with a corresponding Hugr node. Args: @@ -188,10 +194,34 @@ class CompiledPytketDef(CompiledFunctionDef, CompiledCallableDef): defined_at: The AST node where the function was defined. ty: The type of the function. python_scope: The Python scope where the function was defined. + input_circuit: The user-provided pytket circuit. func_df: The Hugr function definition. """ func_def: hf.Function + def load_with_args( + self, + type_args: Inst, + dfg: DFContainer, + globals: CompiledGlobals, + node: AstNode, + ) -> Wire: + """Loads the function as a value into a local Hugr dataflow graph.""" + # Use implementation from function definition. + return load_with_args(type_args, dfg, self.ty, self.func_def) + + def compile_call( + self, + args: list[Wire], + type_args: Inst, + dfg: DFContainer, + globals: CompiledGlobals, + node: AstNode, + ) -> CallReturnWires: + """Compiles a call to the function.""" + # Use implementation from function definition. + return compile_call(args, type_args, dfg, self.ty, self.func_def) + def compile_inner(self, globals: CompiledGlobals) -> None: pass From 86c89d0380262025b2bd1e3095f258c6409efd6f Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Wed, 27 Nov 2024 16:02:50 +0000 Subject: [PATCH 08/21] Implement hugr insertion of circuit function --- guppylang/definition/pytket_circuits.py | 16 ++++++------ tests/integration/test_pytket_circuits.py | 30 +++++++++++++++++------ 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index 2d2f0468..23e756ca 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -92,10 +92,8 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": row_to_type([bool_type()] * self.input_circuit.n_bits), ) # TODO: Allow arrays in stub signature. - if not ( - circuit_signature.inputs == stub_signature.inputs - and circuit_signature.output == stub_signature.output - ): + # TODO: Comparing outputs? + if not circuit_signature.inputs == stub_signature.inputs: raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) except ImportError: err = Tket2NotInstalled(func_ast) @@ -137,8 +135,8 @@ class ParsedPytketDef(CallableDef, CompilableDef): def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" - hugr_func = None + try: import pytket @@ -147,9 +145,11 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" Tk2Circuit, ) - hugr_func = Hugr.load_json( - Tk2Circuit(self.input_circuit).to_hugr_json() - ) # type: ignore[attr-defined, unused-ignore] + circ = Hugr.load_json(Tk2Circuit(self.input_circuit).to_hugr_json()) # type: ignore[attr-defined, unused-ignore] + + mapping = module.hugr.insert_hugr(circ) + hugr_func = mapping[circ.root] + else: raise GuppyError(PytketNotCircuit(self.defined_at)) except ImportError: diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index 3653665f..ba0e36fb 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -12,7 +12,7 @@ tket2_installed = find_spec("tket2") is not None -# @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") def test_single_qubit_circuit(validate): from pytket import Circuit @@ -23,16 +23,30 @@ def test_single_qubit_circuit(validate): module.load_all(quantum) @guppy.pytket(circ, module) - def guppy_circ(q1: qubit) -> None: - """Function stub for circuit""" - - @guppy(module) - def bar(q: qubit) -> None: - pass + def guppy_circ(q1: qubit) -> None: ... @guppy(module) def foo(q: qubit) -> None: - bar(q) guppy_circ(q) + validate(module.compile()) + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +def test_multi_qubit_circuit(validate): + from pytket import Circuit + + circ = Circuit(2) + circ.H(0) + circ.CX(0, 1) + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q1: qubit, q2: qubit) -> None: ... + + @guppy(module) + def foo(q1: qubit, q2: qubit) -> None: + guppy_circ(q1, q2) + validate(module.compile()) \ No newline at end of file From 64176fcb089774865ffc50ec27805c45f61b5b2f Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Thu, 28 Nov 2024 08:54:28 +0000 Subject: [PATCH 09/21] Add false positive typo check fix --- .typos.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/.typos.toml b/.typos.toml index 26e3321f..acf2b12c 100644 --- a/.typos.toml +++ b/.typos.toml @@ -4,3 +4,4 @@ ine = "ine" inot = "inot" inout = "inout" inouts = "inouts" +anc = "anc" From 273600d319de36b66fb9bedecd51bee8b0272ffd Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Thu, 28 Nov 2024 10:33:36 +0000 Subject: [PATCH 10/21] Change hugr function name and add more tests --- guppylang/decorator.py | 2 +- guppylang/definition/pytket_circuits.py | 17 ++++-- tests/integration/test_pytket_circuits.py | 66 +++++++++++++++++++++++ 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index d3787a44..e5992abe 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -472,7 +472,7 @@ def registered_modules(self) -> KeysView[ModuleIdentifier]: def pytket( self, input_circuit: Any, module: GuppyModule | None = None ) -> PytketDecorator: - """Adds a pytket circuit function definition.""" + """Adds a pytket circuit function definition with explicit signature.""" module = module or self.get_module() def dec(f: PyFunc) -> RawPytketDef: diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index 23e756ca..99d36158 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -2,8 +2,7 @@ from dataclasses import dataclass, field from typing import Any, cast -import hugr.build.function as hf -from hugr import Hugr, Wire +from hugr import Hugr, Node, Wire from hugr.build.dfg import DefinitionBuilder, OpVar from guppylang.ast_util import AstNode, has_empty_body, with_loc @@ -75,6 +74,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": func_ast, globals.with_python_scope(self.python_scope) ) + # TODO: Allow arrays as arguments. # Retrieve circuit signature and compare. try: import pytket @@ -91,8 +91,9 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": * self.input_circuit.n_qubits, row_to_type([bool_type()] * self.input_circuit.n_bits), ) - # TODO: Allow arrays in stub signature. - # TODO: Comparing outputs? + # Note this doesn't set the output type. + print(circuit_signature.inputs) + print(stub_signature.inputs) if not circuit_signature.inputs == stub_signature.inputs: raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) except ImportError: @@ -103,6 +104,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": raise GuppyError(PytketNotCircuit(func_ast)) except ImportError: pass + return ParsedPytketDef( self.id, self.name, @@ -150,6 +152,11 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" mapping = module.hugr.insert_hugr(circ) hugr_func = mapping[circ.root] + # TODO: Check that node data op is FuncDefn. + node_data = module.hugr.get(hugr_func) + node_data.op.f_name = self.name + print(module.hugr.get(hugr_func)) + else: raise GuppyError(PytketNotCircuit(self.defined_at)) except ImportError: @@ -198,7 +205,7 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef): func_df: The Hugr function definition. """ - func_def: hf.Function + func_def: Node def load_with_args( self, diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index ba0e36fb..8fefd83b 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -31,6 +31,7 @@ def foo(q: qubit) -> None: validate(module.compile()) + @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") def test_multi_qubit_circuit(validate): from pytket import Circuit @@ -49,4 +50,69 @@ def guppy_circ(q1: qubit, q2: qubit) -> None: ... def foo(q1: qubit, q2: qubit) -> None: guppy_circ(q1, q2) + validate(module.compile()) + + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("Classical bits lead to port error") +def test_classic_bits(validate): + from pytket import Circuit + + circ = Circuit(2, 2) + circ.H(0) + circ.CX(0, 1) + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q1: qubit, q2: qubit) -> None: ... + + @guppy(module) + def foo(q1: qubit, q2: qubit) -> None: + guppy_circ(q1, q2) + + validate(module.compile()) + + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("Measuring leads to port error") +def test_measure_circuit(validate): + from pytket import Circuit + + circ = Circuit(2) + circ.H(0) + circ.CX(0, 1) + circ.measure_all() + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q1: qubit, q2: qubit) -> None: ... + + @guppy(module) + def foo(q1: qubit, q2: qubit) -> None: + guppy_circ(q1, q2) + + validate(module.compile()) + + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("Not implemented") +def test_load_circuit(validate): + from pytket import Circuit + + circ = Circuit(1) + circ.H(0) + + module = GuppyModule("test") + module.load_all(quantum) + + guppy_circ = guppy.load_pytket("guppy_circ", circ, module) + + @guppy(module) + def foo(q: qubit) -> None: + guppy_circ(q) + validate(module.compile()) \ No newline at end of file From 787d515fefc6094f4e82ba283a58c5f731e10043 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Thu, 28 Nov 2024 16:28:02 +0000 Subject: [PATCH 11/21] Fix typing errors --- guppylang/decorator.py | 8 ++--- guppylang/definition/declaration.py | 29 +++++++---------- guppylang/definition/function.py | 5 +-- guppylang/definition/pytket_circuits.py | 14 +++----- guppylang/module.py | 6 ++-- tests/integration/test_pytket_circuits.py | 39 +++++------------------ 6 files changed, 33 insertions(+), 68 deletions(-) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index e5992abe..e85a941e 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -473,12 +473,12 @@ def pytket( self, input_circuit: Any, module: GuppyModule | None = None ) -> PytketDecorator: """Adds a pytket circuit function definition with explicit signature.""" - module = module or self.get_module() + mod = module or self.get_module() - def dec(f: PyFunc) -> RawPytketDef: - return module.register_pytket_func(f, input_circuit) + def func(f: PyFunc) -> RawPytketDef: + return mod.register_pytket_func(f, input_circuit) - return dec + return func class _GuppyDummy: diff --git a/guppylang/definition/declaration.py b/guppylang/definition/declaration.py index c35f3e73..57f812cf 100644 --- a/guppylang/definition/declaration.py +++ b/guppylang/definition/declaration.py @@ -3,7 +3,6 @@ from typing import ClassVar from hugr import Node, Wire -from hugr import tys as ht from hugr.build import function as hf from hugr.build.dfg import DefinitionBuilder, OpVar @@ -13,14 +12,19 @@ from guppylang.checker.func_checker import check_signature from guppylang.compiler.core import CompiledGlobals, DFContainer from guppylang.definition.common import CompilableDef, ParsableDef -from guppylang.definition.function import PyFunc, parse_py_func +from guppylang.definition.function import ( + PyFunc, + compile_call, + load_with_args, + parse_py_func, +) from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef from guppylang.diagnostic import Error from guppylang.error import GuppyError from guppylang.nodes import GlobalCall from guppylang.span import SourceMap from guppylang.tys.subst import Inst, Subst -from guppylang.tys.ty import Type, type_to_row +from guppylang.tys.ty import Type @dataclass(frozen=True) @@ -121,9 +125,8 @@ def load_with_args( node: AstNode, ) -> Wire: """Loads the function as a value into a local Hugr dataflow graph.""" - func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() - type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] - return dfg.builder.load_function(self.declaration, func_ty, type_args) + # Use implementation from function definition. + return load_with_args(type_args, dfg, self.ty, self.declaration) def compile_call( self, @@ -134,15 +137,5 @@ def compile_call( node: AstNode, ) -> CallReturnWires: """Compiles a call to the function.""" - func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() - type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] - num_returns = len(type_to_row(self.ty.output)) - call = dfg.builder.call( - self.declaration, *args, instantiation=func_ty, type_args=type_args - ) - return CallReturnWires( - # TODO: Replace below with `list(call[:num_returns])` once - # https://github.com/CQCL/hugr/issues/1454 is fixed. - regular_returns=[call[i] for i in range(num_returns)], - inout_returns=list(call[num_returns:]), - ) + # Use implementation from function definition. + return compile_call(args, type_args, dfg, self.ty, self.declaration) diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index b77fc601..156f3b2f 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -9,6 +9,7 @@ import hugr.tys as ht from hugr import Wire from hugr.build.dfg import DefinitionBuilder, OpVar +from hugr.hugr.node_port import ToNode from hugr.package import FuncDefnPointer from guppylang.ast_util import AstNode, annotate_location, with_loc @@ -222,7 +223,7 @@ def load_with_args( dfg: DFContainer, ty: FunctionType, # TODO: Maybe change to ToNode so this can be used by declarations. - func: hf.Function, + func: ToNode, ) -> Wire: """Loads the function as a value into a local Hugr dataflow graph.""" func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr() @@ -236,7 +237,7 @@ def compile_call( dfg: DFContainer, ty: FunctionType, # TODO: Maybe change to ToNode so this can be used by declarations. - func: hf.Function, + func: ToNode, ) -> CallReturnWires: """Compiles a call to the function.""" func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr() diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index 99d36158..fc6918ed 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -4,6 +4,7 @@ from hugr import Hugr, Node, Wire from hugr.build.dfg import DefinitionBuilder, OpVar +from hugr.ops import FuncDefn from guppylang.ast_util import AstNode, has_empty_body, with_loc from guppylang.checker.core import Context, Globals, PyScope @@ -92,8 +93,6 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": row_to_type([bool_type()] * self.input_circuit.n_bits), ) # Note this doesn't set the output type. - print(circuit_signature.inputs) - print(stub_signature.inputs) if not circuit_signature.inputs == stub_signature.inputs: raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) except ImportError: @@ -137,8 +136,6 @@ class ParsedPytketDef(CallableDef, CompilableDef): def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef": """Adds a Hugr `FuncDefn` node for this function to the Hugr.""" - hugr_func = None - try: import pytket @@ -152,10 +149,10 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" mapping = module.hugr.insert_hugr(circ) hugr_func = mapping[circ.root] - # TODO: Check that node data op is FuncDefn. node_data = module.hugr.get(hugr_func) - node_data.op.f_name = self.name - print(module.hugr.get(hugr_func)) + # TODO: Handle case if it isn't. + if isinstance(node_data, FuncDefn): + node_data.op.f_name = self.name else: raise GuppyError(PytketNotCircuit(self.defined_at)) @@ -229,6 +226,3 @@ def compile_call( """Compiles a call to the function.""" # Use implementation from function definition. return compile_call(args, type_args, dfg, self.ty, self.func_def) - - def compile_inner(self, globals: CompiledGlobals) -> None: - pass diff --git a/guppylang/module.py b/guppylang/module.py index c9fc8a48..f51fab70 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -242,11 +242,11 @@ def register_pytket_func( self, f: PyFunc, input_value: Any, instance: TypeDef | None = None ) -> RawPytketDef: """Registers a pytket circuit function as belonging to this Guppy module.""" - decl = RawPytketDef( + func = RawPytketDef( DefId.fresh(self), f.__name__, None, f, get_py_scope(f), input_value ) - self.register_def(decl, instance) - return decl + self.register_def(func, instance) + return func def _register_buffered_instance_funcs(self, instance: TypeDef) -> None: assert self._instance_func_buffer is not None diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index 8fefd83b..972fa5e0 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -54,46 +54,23 @@ def foo(q1: qubit, q2: qubit) -> None: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("Classical bits lead to port error") -def test_classic_bits(validate): - from pytket import Circuit - - circ = Circuit(2, 2) - circ.H(0) - circ.CX(0, 1) - - module = GuppyModule("test") - module.load_all(quantum) - - @guppy.pytket(circ, module) - def guppy_circ(q1: qubit, q2: qubit) -> None: ... - - @guppy(module) - def foo(q1: qubit, q2: qubit) -> None: - guppy_circ(q1, q2) - - validate(module.compile()) - - -@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("Measuring leads to port error") +@pytest.mark.skip("Not implemented") def test_measure_circuit(validate): from pytket import Circuit - circ = Circuit(2) - circ.H(0) - circ.CX(0, 1) - circ.measure_all() + circ = Circuit(1, 1) + circ.H(0) + circ.measure_all() module = GuppyModule("test") module.load_all(quantum) @guppy.pytket(circ, module) - def guppy_circ(q1: qubit, q2: qubit) -> None: ... + def guppy_circ(q: qubit) -> None: ... @guppy(module) - def foo(q1: qubit, q2: qubit) -> None: - guppy_circ(q1, q2) + def foo(q: qubit) -> None: + guppy_circ(q) validate(module.compile()) @@ -109,7 +86,7 @@ def test_load_circuit(validate): module = GuppyModule("test") module.load_all(quantum) - guppy_circ = guppy.load_pytket("guppy_circ", circ, module) + guppy.load_pytket("guppy_circ", circ, module) @guppy(module) def foo(q: qubit) -> None: From 8164111d39d742ba4b30b331513188f572030041 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Fri, 29 Nov 2024 09:53:26 +0000 Subject: [PATCH 12/21] Fix check and adjust test --- guppylang/definition/pytket_circuits.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index fc6918ed..d34ca9ea 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -150,9 +150,10 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" hugr_func = mapping[circ.root] node_data = module.hugr.get(hugr_func) - # TODO: Handle case if it isn't. - if isinstance(node_data, FuncDefn): - node_data.op.f_name = self.name + + if node_data and isinstance(node_data.op, FuncDefn): + func_def = node_data.op + func_def.f_name = self.name else: raise GuppyError(PytketNotCircuit(self.defined_at)) From 994e0acc9ea6e2492f7065cdd47de5187e6e7484 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Fri, 29 Nov 2024 09:53:43 +0000 Subject: [PATCH 13/21] Adjust test --- tests/integration/test_pytket_circuits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index 972fa5e0..d84e53dd 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -66,10 +66,10 @@ def test_measure_circuit(validate): module.load_all(quantum) @guppy.pytket(circ, module) - def guppy_circ(q: qubit) -> None: ... + def guppy_circ(q: qubit) -> bool: ... @guppy(module) - def foo(q: qubit) -> None: + def foo(q: qubit) -> bool: guppy_circ(q) validate(module.compile()) From f25f702fa501a183dc7964894dc98b0218e4cb6f Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Mon, 2 Dec 2024 17:32:29 +0000 Subject: [PATCH 14/21] Attempts to manipulate function signature to make measurements work --- guppylang/definition/function.py | 2 - guppylang/definition/pytket_circuits.py | 41 ++++++++++--- tests/integration/test_pytket_circuits.py | 75 ++++++++++++++++++++++- 3 files changed, 106 insertions(+), 12 deletions(-) diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index 156f3b2f..25f0310c 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -222,7 +222,6 @@ def load_with_args( type_args: Inst, dfg: DFContainer, ty: FunctionType, - # TODO: Maybe change to ToNode so this can be used by declarations. func: ToNode, ) -> Wire: """Loads the function as a value into a local Hugr dataflow graph.""" @@ -236,7 +235,6 @@ def compile_call( type_args: Inst, dfg: DFContainer, ty: FunctionType, - # TODO: Maybe change to ToNode so this can be used by declarations. func: ToNode, ) -> CallReturnWires: """Compiles a call to the function.""" diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index d34ca9ea..b363aab3 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -2,9 +2,11 @@ from dataclasses import dataclass, field from typing import Any, cast +import hugr.tys as ht from hugr import Hugr, Node, Wire from hugr.build.dfg import DefinitionBuilder, OpVar from hugr.ops import FuncDefn +from hugr.tys import Bool from guppylang.ast_util import AstNode, has_empty_body, with_loc from guppylang.checker.core import Context, Globals, PyScope @@ -23,7 +25,6 @@ from guppylang.definition.declaration import BodyNotEmptyError from guppylang.definition.function import ( PyFunc, - compile_call, load_with_args, parse_py_func, ) @@ -40,6 +41,7 @@ InputFlags, Type, row_to_type, + type_to_row, ) @@ -87,13 +89,17 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": qubit = cast(TypeDef, globals["qubit"]).check_instantiate( [], globals ) + circuit_signature = FunctionType( [FuncInput(qubit, InputFlags.Inout)] * self.input_circuit.n_qubits, row_to_type([bool_type()] * self.input_circuit.n_bits), ) - # Note this doesn't set the output type. - if not circuit_signature.inputs == stub_signature.inputs: + + if not ( + circuit_signature.inputs == stub_signature.inputs + and circuit_signature.output == stub_signature.output + ): raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) except ImportError: err = Tket2NotInstalled(func_ast) @@ -151,9 +157,14 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" node_data = module.hugr.get(hugr_func) + # TODO: Error if hugr isn't FuncDefn? if node_data and isinstance(node_data.op, FuncDefn): - func_def = node_data.op - func_def.f_name = self.name + func_defn = node_data.op + func_defn.f_name = self.name + + num_bools = func_defn.inputs.count(Bool) + for _ in range(num_bools): + func_defn.inputs.remove(Bool) else: raise GuppyError(PytketNotCircuit(self.defined_at)) @@ -225,5 +236,21 @@ def compile_call( node: AstNode, ) -> CallReturnWires: """Compiles a call to the function.""" - # Use implementation from function definition. - return compile_call(args, type_args, dfg, self.ty, self.func_def) + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + num_returns = len(type_to_row(self.ty.output)) + call = dfg.builder.call( + self.func_def, *args, instantiation=func_ty, type_args=type_args + ) + # Negative index slicing doesn't work when num_returns is 0. + if num_returns == 0: + return CallReturnWires( + regular_returns=[], + inout_returns=list(call[0:]), + ) + else: + # Circuit function returns are the other way round to Guppy returns. + return CallReturnWires( + regular_returns=list(call[-num_returns:]), + inout_returns=list(call[:-num_returns]), + ) diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index d84e53dd..29dba330 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -53,14 +53,59 @@ def foo(q1: qubit, q2: qubit) -> None: validate(module.compile()) +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +def test_measure(validate): + from pytket import Circuit + + circ = Circuit(1) + circ.H(0) + circ.measure_all() + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q: qubit) -> bool: ... + + @guppy(module) + def foo(q: qubit) -> None: + result = guppy_circ(q) + + print(module.compile_hugr().to_json()) + validate(module.compile()) + + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("Not implemented") +def test_measure_multiple(validate): + from pytket import Circuit + + circ = Circuit(2, 2) + circ.H(0) + circ.measure_all() + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q1: qubit, q2: qubit) -> tuple[bool, bool]: ... + + @guppy(module) + def foo(q1: qubit, q2: qubit) -> None: + result = guppy_circ(q1, q2) + + validate(module.compile()) + + @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") @pytest.mark.skip("Not implemented") -def test_measure_circuit(validate): +def test_measure_not_last(validate): from pytket import Circuit circ = Circuit(1, 1) circ.H(0) circ.measure_all() + circ.X(0) module = GuppyModule("test") module.load_all(quantum) @@ -69,8 +114,32 @@ def test_measure_circuit(validate): def guppy_circ(q: qubit) -> bool: ... @guppy(module) - def foo(q: qubit) -> bool: - guppy_circ(q) + def foo(q: qubit) -> None: + result = guppy_circ(q) + + validate(module.compile()) + + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("Not implemented") +def test_measure_some(validate): + from pytket import Circuit + + circ = Circuit(2, 2) + circ.H(0) + circ.Rz(0.25, 0) + circ.CX(1, 0) + circ.measure_register(0) + + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.pytket(circ, module) + def guppy_circ(q1: qubit, q2: qubit) -> bool: ... + + @guppy(module) + def foo(q1: qubit, q2: qubit) -> None: + result = guppy_circ(q1, q2) validate(module.compile()) From c72c519bb4e9b23f5e9b00f1651d55c8ff4b71bc Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Tue, 3 Dec 2024 13:15:52 +0000 Subject: [PATCH 15/21] Modify input node to match function definition --- guppylang/definition/pytket_circuits.py | 14 ++++++++---- tests/integration/test_pytket_circuits.py | 27 ----------------------- 2 files changed, 10 insertions(+), 31 deletions(-) diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index b363aab3..0d10bf11 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -5,7 +5,7 @@ import hugr.tys as ht from hugr import Hugr, Node, Wire from hugr.build.dfg import DefinitionBuilder, OpVar -from hugr.ops import FuncDefn +from hugr.ops import FuncDefn, Input from hugr.tys import Bool from guppylang.ast_util import AstNode, has_empty_body, with_loc @@ -151,21 +151,27 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" ) circ = Hugr.load_json(Tk2Circuit(self.input_circuit).to_hugr_json()) # type: ignore[attr-defined, unused-ignore] - mapping = module.hugr.insert_hugr(circ) hugr_func = mapping[circ.root] + # We need to remove input bits from both signature and input node. node_data = module.hugr.get(hugr_func) - # TODO: Error if hugr isn't FuncDefn? if node_data and isinstance(node_data.op, FuncDefn): func_defn = node_data.op func_defn.f_name = self.name - num_bools = func_defn.inputs.count(Bool) for _ in range(num_bools): func_defn.inputs.remove(Bool) + for child in module.hugr.children(hugr_func): + node_data = module.hugr.get(child) + if node_data and isinstance(node_data.op, Input): + input_types = node_data.op.types + num_bools = input_types.count(Bool) + for _ in range(num_bools): + input_types.remove(Bool) + else: raise GuppyError(PytketNotCircuit(self.defined_at)) except ImportError: diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index 29dba330..aea4fa6b 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -71,12 +71,10 @@ def guppy_circ(q: qubit) -> bool: ... def foo(q: qubit) -> None: result = guppy_circ(q) - print(module.compile_hugr().to_json()) validate(module.compile()) @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("Not implemented") def test_measure_multiple(validate): from pytket import Circuit @@ -98,7 +96,6 @@ def foo(q1: qubit, q2: qubit) -> None: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("Not implemented") def test_measure_not_last(validate): from pytket import Circuit @@ -120,30 +117,6 @@ def foo(q: qubit) -> None: validate(module.compile()) -@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("Not implemented") -def test_measure_some(validate): - from pytket import Circuit - - circ = Circuit(2, 2) - circ.H(0) - circ.Rz(0.25, 0) - circ.CX(1, 0) - circ.measure_register(0) - - module = GuppyModule("test") - module.load_all(quantum) - - @guppy.pytket(circ, module) - def guppy_circ(q1: qubit, q2: qubit) -> bool: ... - - @guppy(module) - def foo(q1: qubit, q2: qubit) -> None: - result = guppy_circ(q1, q2) - - validate(module.compile()) - - @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") @pytest.mark.skip("Not implemented") def test_load_circuit(validate): From ab20ac58d09efbc4cd7c85bb853ec37a87f98fdc Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Wed, 4 Dec 2024 15:59:57 +0000 Subject: [PATCH 16/21] Address comments --- .github/workflows/pull-request.yaml | 2 +- guppylang/checker/errors/py_errors.py | 10 +--- guppylang/decorator.py | 10 ++++ guppylang/definition/pytket_circuits.py | 68 +++++++++-------------- tests/error/py_errors/sig_mismatch.err | 9 +++ tests/error/py_errors/sig_mismatch.py | 22 ++++++++ tests/integration/test_pytket_circuits.py | 15 +++-- 7 files changed, 78 insertions(+), 58 deletions(-) create mode 100644 tests/error/py_errors/sig_mismatch.err create mode 100644 tests/error/py_errors/sig_mismatch.py diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 76bca634..0444e694 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -65,7 +65,7 @@ jobs: run: uv sync --extra pytket - name: Rerun `py(...)` expression tests and pytket lowering with tket2 installed - run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py + run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py /tests/integration/test_pytket_circuits.py test-coverage: name: Check Python (3.13) with coverage diff --git a/guppylang/checker/errors/py_errors.py b/guppylang/checker/errors/py_errors.py index 3158a3a1..1c5a984e 100644 --- a/guppylang/checker/errors/py_errors.py +++ b/guppylang/checker/errors/py_errors.py @@ -58,13 +58,5 @@ class InstallInstruction(Help): @dataclass(frozen=True) class PytketSignatureMismatch(Error): title: ClassVar[str] = "Signature mismatch" - span_label: ClassVar[str] = ( - "Function signature {name} doesn't match provided pytket circuit" - ) + span_label: ClassVar[str] = "Signature {name} doesn't match provided pytket circuit" name: str - - -@dataclass(frozen=True) -class PytketNotCircuit(Error): - title: ClassVar[str] = "Input not circuit" - span_label: ClassVar[str] = "Provided input is not a pytket circuit" diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 74f4309d..b0ea43fe 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -475,6 +475,16 @@ def pytket( self, input_circuit: Any, module: GuppyModule | None = None ) -> PytketDecorator: """Adds a pytket circuit function definition with explicit signature.""" + err_msg = "Only pytket circuits can be passed to guppy.pytket" + try: + import pytket + + if not isinstance(input_circuit, pytket.circuit.Circuit): + raise TypeError(err_msg) from None + + except ImportError: + raise TypeError(err_msg) from None + mod = module or self.get_module() def func(f: PyFunc) -> RawPytketDef: diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index 0d10bf11..df5a377f 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -2,16 +2,13 @@ from dataclasses import dataclass, field from typing import Any, cast -import hugr.tys as ht -from hugr import Hugr, Node, Wire +import hugr.build.function as hf +from hugr import Hugr, OutPort, Wire, ops, tys, val from hugr.build.dfg import DefinitionBuilder, OpVar -from hugr.ops import FuncDefn, Input -from hugr.tys import Bool from guppylang.ast_util import AstNode, has_empty_body, with_loc from guppylang.checker.core import Context, Globals, PyScope from guppylang.checker.errors.py_errors import ( - PytketNotCircuit, PytketSignatureMismatch, Tket2NotInstalled, ) @@ -25,6 +22,7 @@ from guppylang.definition.declaration import BodyNotEmptyError from guppylang.definition.function import ( PyFunc, + compile_call, load_with_args, parse_py_func, ) @@ -41,7 +39,6 @@ InputFlags, Type, row_to_type, - type_to_row, ) @@ -100,13 +97,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": circuit_signature.inputs == stub_signature.inputs and circuit_signature.output == stub_signature.output ): + # TODO: Implement pretty-printing for signatures in order to add + # a note for expected vs. actual types. raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) except ImportError: err = Tket2NotInstalled(func_ast) err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None)) raise GuppyError(err) from None - else: - raise GuppyError(PytketNotCircuit(func_ast)) except ImportError: pass @@ -154,26 +151,29 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" mapping = module.hugr.insert_hugr(circ) hugr_func = mapping[circ.root] - # We need to remove input bits from both signature and input node. - node_data = module.hugr.get(hugr_func) - # TODO: Error if hugr isn't FuncDefn? - if node_data and isinstance(node_data.op, FuncDefn): - func_defn = node_data.op - func_defn.f_name = self.name - num_bools = func_defn.inputs.count(Bool) - for _ in range(num_bools): - func_defn.inputs.remove(Bool) + func_type = self.ty.to_hugr_poly() + outer_func = module.define_function( + self.name, func_type.body.input, func_type.body.output + ) + # Initialise every input bit in the circuit as false. + # TODO: Provide the option for the user to pass this input as well. + bool_wires: list[OutPort] = [] for child in module.hugr.children(hugr_func): node_data = module.hugr.get(child) - if node_data and isinstance(node_data.op, Input): + if node_data and isinstance(node_data.op, ops.Input): input_types = node_data.op.types - num_bools = input_types.count(Bool) + num_bools = input_types.count(tys.Bool) for _ in range(num_bools): - input_types.remove(Bool) + bool_node = outer_func.load(val.FALSE) + bool_wires.append(*bool_node.outputs()) + + call_node = outer_func.call( + hugr_func, *(list(outer_func.inputs()) + bool_wires) + ) + # Pytket circuit hugr has qubit and bool wires in the opposite order. + outer_func.set_outputs(*reversed(list(call_node.outputs()))) - else: - raise GuppyError(PytketNotCircuit(self.defined_at)) except ImportError: pass @@ -184,7 +184,7 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" self.ty, self.python_scope, self.input_circuit, - hugr_func, + outer_func, ) def check_call( @@ -220,7 +220,7 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef): func_df: The Hugr function definition. """ - func_def: Node + func_def: hf.Function def load_with_args( self, @@ -242,21 +242,5 @@ def compile_call( node: AstNode, ) -> CallReturnWires: """Compiles a call to the function.""" - func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() - type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] - num_returns = len(type_to_row(self.ty.output)) - call = dfg.builder.call( - self.func_def, *args, instantiation=func_ty, type_args=type_args - ) - # Negative index slicing doesn't work when num_returns is 0. - if num_returns == 0: - return CallReturnWires( - regular_returns=[], - inout_returns=list(call[0:]), - ) - else: - # Circuit function returns are the other way round to Guppy returns. - return CallReturnWires( - regular_returns=list(call[-num_returns:]), - inout_returns=list(call[:-num_returns]), - ) + # Use implementation from function definition. + return compile_call(args, type_args, dfg, self.ty, self.func_def) diff --git a/tests/error/py_errors/sig_mismatch.err b/tests/error/py_errors/sig_mismatch.err new file mode 100644 index 00000000..2c107dfc --- /dev/null +++ b/tests/error/py_errors/sig_mismatch.err @@ -0,0 +1,9 @@ +Error: Signature mismatch (at :15:0) + | +13 | +14 | @guppy.pytket(circ, module) +15 | def guppy_circ(q: qubit) -> None: ... + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Function signature guppy_circ doesn't match provided pytket + | circuit + +Guppy compilation failed due to 1 previous error \ No newline at end of file diff --git a/tests/error/py_errors/sig_mismatch.py b/tests/error/py_errors/sig_mismatch.py new file mode 100644 index 00000000..ec3f268d --- /dev/null +++ b/tests/error/py_errors/sig_mismatch.py @@ -0,0 +1,22 @@ +from pytket import Circuit + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit + +circ = Circuit(2) +circ.X(0) +circ.Y(1) + +module = GuppyModule("test") +module.load(qubit) + +@guppy.pytket(circ, module) +def guppy_circ(q: qubit) -> None: ... + +@guppy(module) +def foo(q: qubit) -> None: + guppy_circ(q) + + +module.compile() \ No newline at end of file diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index aea4fa6b..e117af7e 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -33,6 +33,7 @@ def foo(q: qubit) -> None: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("remove this") def test_multi_qubit_circuit(validate): from pytket import Circuit @@ -68,13 +69,14 @@ def test_measure(validate): def guppy_circ(q: qubit) -> bool: ... @guppy(module) - def foo(q: qubit) -> None: - result = guppy_circ(q) + def foo(q: qubit) -> bool: + return guppy_circ(q) validate(module.compile()) @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("remove this") def test_measure_multiple(validate): from pytket import Circuit @@ -89,13 +91,14 @@ def test_measure_multiple(validate): def guppy_circ(q1: qubit, q2: qubit) -> tuple[bool, bool]: ... @guppy(module) - def foo(q1: qubit, q2: qubit) -> None: - result = guppy_circ(q1, q2) + def foo(q1: qubit, q2: qubit) -> tuple[bool, bool]: + return guppy_circ(q1, q2) validate(module.compile()) @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +@pytest.mark.skip("remove this") def test_measure_not_last(validate): from pytket import Circuit @@ -111,8 +114,8 @@ def test_measure_not_last(validate): def guppy_circ(q: qubit) -> bool: ... @guppy(module) - def foo(q: qubit) -> None: - result = guppy_circ(q) + def foo(q: qubit) -> bool: + return guppy_circ(q) validate(module.compile()) From 1d417e5416280eba31717ab0ff915aeeb598ea4a Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Wed, 4 Dec 2024 16:22:04 +0000 Subject: [PATCH 17/21] Fix sig mismatch error test --- .github/workflows/pull-request.yaml | 2 +- tests/error/py_errors/sig_mismatch.err | 7 +++---- tests/error/test_py_errors.py | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 0444e694..27dfe962 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -65,7 +65,7 @@ jobs: run: uv sync --extra pytket - name: Rerun `py(...)` expression tests and pytket lowering with tket2 installed - run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py /tests/integration/test_pytket_circuits.py + run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py tests/integration/test_pytket_circuits.py test-coverage: name: Check Python (3.13) with coverage diff --git a/tests/error/py_errors/sig_mismatch.err b/tests/error/py_errors/sig_mismatch.err index 2c107dfc..ba5c8d31 100644 --- a/tests/error/py_errors/sig_mismatch.err +++ b/tests/error/py_errors/sig_mismatch.err @@ -1,9 +1,8 @@ -Error: Signature mismatch (at :15:0) +Error: Signature mismatch (at $FILE:15:0) | 13 | 14 | @guppy.pytket(circ, module) 15 | def guppy_circ(q: qubit) -> None: ... - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Function signature guppy_circ doesn't match provided pytket - | circuit + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Signature guppy_circ doesn't match provided pytket circuit -Guppy compilation failed due to 1 previous error \ No newline at end of file +Guppy compilation failed due to 1 previous error diff --git a/tests/error/test_py_errors.py b/tests/error/test_py_errors.py index b5c0b285..bdb6924c 100644 --- a/tests/error/test_py_errors.py +++ b/tests/error/test_py_errors.py @@ -23,6 +23,7 @@ @pytest.mark.parametrize("file", files) +@pytest.mark.skipif(not tket2_installed, reason="tket2 is not installed") def test_py_errors(file, capsys, snapshot): run_error_test(file, capsys, snapshot) From d41198b197312b6e4f29c7b22b47374e5479b4ae Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Wed, 4 Dec 2024 17:07:03 +0000 Subject: [PATCH 18/21] Add tket2 dependency --- pyproject.toml | 2 +- uv.lock | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7b501a58..e018f49f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,6 @@ homepage = "https://github.com/CQCL/guppylang" repository = "https://github.com/CQCL/guppylang" [dependency-groups] -# Default dev dependency group dev = [ { include-group = "lint" }, { include-group = "test" }, @@ -67,6 +66,7 @@ test = [ "pytest-notebook >=0.10.0,<0.11", "pytest-snapshot >=0.9.0,<1", "ipykernel >=6.29.5,<7", + "tket2>=0.5.0", ] llvm_integration = [ { include-group = "test" }, diff --git a/uv.lock b/uv.lock index 4d95e9b3..4c14daab 100644 --- a/uv.lock +++ b/uv.lock @@ -577,6 +577,7 @@ dev = [ { name = "pytest-snapshot" }, { name = "pytket" }, { name = "ruff" }, + { name = "tket2" }, ] lint = [ { name = "mypy" }, @@ -591,6 +592,7 @@ llvm-integration = [ { name = "pytest-cov" }, { name = "pytest-notebook" }, { name = "pytest-snapshot" }, + { name = "tket2" }, ] pytket-integration = [ { name = "ipykernel" }, @@ -599,6 +601,7 @@ pytket-integration = [ { name = "pytest-notebook" }, { name = "pytest-snapshot" }, { name = "pytket" }, + { name = "tket2" }, ] test = [ { name = "ipykernel" }, @@ -606,6 +609,7 @@ test = [ { name = "pytest-cov" }, { name = "pytest-notebook" }, { name = "pytest-snapshot" }, + { name = "tket2" }, ] [package.metadata] @@ -635,6 +639,7 @@ dev = [ { name = "pytest-snapshot", specifier = ">=0.9.0,<1" }, { name = "pytket", specifier = ">=1.34.0,<2" }, { name = "ruff", specifier = ">=0.6.2,<0.7" }, + { name = "tket2", specifier = ">=0.5.0" }, ] lint = [ { name = "mypy", specifier = "==1.10.0" }, @@ -649,6 +654,7 @@ llvm-integration = [ { name = "pytest-cov", specifier = ">=5.0.0,<6" }, { name = "pytest-notebook", specifier = ">=0.10.0,<0.11" }, { name = "pytest-snapshot", specifier = ">=0.9.0,<1" }, + { name = "tket2", specifier = ">=0.5.0" }, ] pytket-integration = [ { name = "ipykernel", specifier = ">=6.29.5,<7" }, @@ -657,6 +663,7 @@ pytket-integration = [ { name = "pytest-notebook", specifier = ">=0.10.0,<0.11" }, { name = "pytest-snapshot", specifier = ">=0.9.0,<1" }, { name = "pytket", specifier = ">=1.34.0,<2" }, + { name = "tket2", specifier = ">=0.5.0" }, ] test = [ { name = "ipykernel", specifier = ">=6.29.5,<7" }, @@ -664,6 +671,7 @@ test = [ { name = "pytest-cov", specifier = ">=5.0.0,<6" }, { name = "pytest-notebook", specifier = ">=0.10.0,<0.11" }, { name = "pytest-snapshot", specifier = ">=0.9.0,<1" }, + { name = "tket2", specifier = ">=0.5.0" }, ] [[package]] @@ -2241,6 +2249,80 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/4d/0db5b8a613d2a59bbc29bc5bb44a2f8070eb9ceab11c50d477502a8a0092/tinycss2-1.3.0-py3-none-any.whl", hash = "sha256:54a8dbdffb334d536851be0226030e9505965bb2f30f21a4a82c55fb2a80fae7", size = 22532 }, ] +[[package]] +name = "tket2" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hugr" }, + { name = "pytket" }, + { name = "tket2-eccs" }, + { name = "tket2-exts" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/7c/2aaa2503ea26ab81c06ba964af97cac28904064fef9c122dbe7e82f921c9/tket2-0.5.0.tar.gz", hash = "sha256:0e983f933e9231bebc6fdaf8c3ffa56b8adf13636c51c06f775664524141f19c", size = 224173 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/a0/34359c683c6657d33c6e92ea565e5138357ba68d7acbc7825259e11616fc/tket2-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:570b1cf6b83fc8c0ac9a67c4e40a8faf5daed8c0b664f64d54ed0746d9ad3789", size = 3803598 }, + { url = "https://files.pythonhosted.org/packages/8d/07/172d33c9f90413d5cdcb09873b040db08b16fa9385c17f0ed6f7a6e12c8e/tket2-0.5.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:86366fadb8740911222a3f3522454badee80f23c142b629753955b7aa2dd351e", size = 4474573 }, + { url = "https://files.pythonhosted.org/packages/6c/9c/c23ae608eb3cd90b3f2b8269080afd385d06c9979176b989a6106fc414bb/tket2-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59afbe455dfb274a6874f03dce6040125aba4fb0c854cddf23edc546ff7e8df3", size = 3939809 }, + { url = "https://files.pythonhosted.org/packages/4f/bf/1dde175e668ab30fda7e09e645c27417a9879b4eb7ad5f8c5152ab2c7c96/tket2-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:31c73a391beef828011f9449d15e12541a31bae25dd52cdfac74b8ffab112f54", size = 3879199 }, + { url = "https://files.pythonhosted.org/packages/44/a0/d9bbb52b9801d4aae9d97cbd30d985052280c62464ed160716cca3c0b0b6/tket2-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d988087136b098155b345bdd4a232d5df6a802adc2a5e3733a3eee642182f4b2", size = 4521382 }, + { url = "https://files.pythonhosted.org/packages/80/f6/9c0c0e6616fed354987b40a8f0b37317a2394ae91c343c2d389b4dbea2fa/tket2-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:74c531f798b1ee3c5a8a3f5111205d52449125bfe4dff825888efdde460a5a69", size = 5459797 }, + { url = "https://files.pythonhosted.org/packages/4c/28/bf5e277c72e33f25cabb58f0674b01ca119f813a3170793f06477fafbdac/tket2-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3fb4482055c7d7beceb8c34beaa1df4e6c0b51f721c141203795f5ac438518b7", size = 4333591 }, + { url = "https://files.pythonhosted.org/packages/33/81/8270a5eb9b5f0c9e06352263e7a55adb3b3dc6550f4e7ea8b6d001a00844/tket2-0.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:afd8a41c08001598173df9fa1eedf3ee90a7ee965845b82c148e77b670eff2c1", size = 4118350 }, + { url = "https://files.pythonhosted.org/packages/e5/c1/57ea758e74f73583bb9622d7c290c77df58c71eccb75171d52b6f2f5c664/tket2-0.5.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:c22312a6d1bb8bf736a2234ad6cbde8d02d56968ecc6a8c788511d50ef87dec0", size = 4104065 }, + { url = "https://files.pythonhosted.org/packages/8b/d4/538ef3cd4acb45b381bce4eef23d049a731907166efa2c621e3f11d4f4b9/tket2-0.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8c621b284a4d56283aa3a717dd22764906ff7bb237328b2cdc4bb31e994305f5", size = 4409060 }, + { url = "https://files.pythonhosted.org/packages/eb/7e/efd9af64a68320428169127f527bf7a00b8b45d22cab8631dd717b6b8580/tket2-0.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9ff8f8e7cf80473715a093d5336b3ff2e9140bb8448647dc57326e8643feb2b8", size = 4474198 }, + { url = "https://files.pythonhosted.org/packages/7f/c0/cee7847316a04cd8136f325aa2275d7aa0601f4573995f1a625eaa031e65/tket2-0.5.0-cp310-none-win32.whl", hash = "sha256:2071b0759c0c4ae81f8d5bc35f3f8c9139a4d6e2ee7e5f8bd4c9317eecbd9939", size = 3744566 }, + { url = "https://files.pythonhosted.org/packages/4e/87/98dc94fe902b74543d01cb83695d4080285ca95c116c93e0210d092eaf3f/tket2-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:d4f0960d55fa77cdf0579aff7e8cd49f902092ce728b1a7cd4741e59e5e2cd11", size = 4196769 }, + { url = "https://files.pythonhosted.org/packages/10/d1/f3356e7368321ad768202920d84845af7ffe48e1479ab82218d75c9925fd/tket2-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:3355f90c08be7be1f04de3b1da47ab092d9f0f9aaee427a4b385a4c778acaa4e", size = 4069056 }, + { url = "https://files.pythonhosted.org/packages/31/53/a1d1ff4f3bbfa99a766b27be9567a297c935a0d269db94d9e20e06d93af5/tket2-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7ee7c6bbc6bfa81040581390323931a2d2e65d0fe28af43bb9fd431e1b3a97a9", size = 3803758 }, + { url = "https://files.pythonhosted.org/packages/d9/6e/37910978999fa248a70a7e9c6704664d8da31204c99944cf294ebe5af7fb/tket2-0.5.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:85c9107c367e8ed95409137ed43f6fc47f5f698881db12b0505c024569f0cb1d", size = 4475499 }, + { url = "https://files.pythonhosted.org/packages/d1/1e/d01ad167e17d3bc715102579970bb1d8e74b89f56c7fecaab098216aa22b/tket2-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4c93f3f43dab15639a7ecf5908a4b40f8c7499aa8c9d2b1c3fb99b43438d2f2", size = 3938290 }, + { url = "https://files.pythonhosted.org/packages/07/65/0f3f6bd294dc4dd2f4eb009da1cb5950550ca7901e55f98867e188aeced5/tket2-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61aacc91facb7ffe1a633311fe4493aa7c5a62bc7ada783f91e6ff33e3c6eba5", size = 3884050 }, + { url = "https://files.pythonhosted.org/packages/e1/c4/405987b62243ff28cc542dea0428a4de2b5517e39c4f0cc92cb7688ee21a/tket2-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f1c0dc27604771cace9c91aa133d14302522536f468e2a7eb1c94ab73ed131f", size = 4521538 }, + { url = "https://files.pythonhosted.org/packages/1a/81/d879e3e03aff7da2557d56894fa6c82f73774ddb64a972a928b744de9b2a/tket2-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6dcf2ed5b30ccaaa8f618e626be939e14043db6f98b17a3ff2b9b9ada9194101", size = 5459463 }, + { url = "https://files.pythonhosted.org/packages/b5/1f/d24edb71b6146cfa1fda02841b1962f6e35037adff96834e7a77eb10b2ba/tket2-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:890906786ba1d4874d726155d063fdcc62286c59077aa9b776a2b404ddd67684", size = 4336956 }, + { url = "https://files.pythonhosted.org/packages/26/d2/4e15b43ebbb2f616b570aaa432f2f65f7877a5bf304c282a49af60a98d86/tket2-0.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:900c3d08be38739a2d7da8892809be02489ec4c0f127477da49a6be4507bb540", size = 4116397 }, + { url = "https://files.pythonhosted.org/packages/ea/b4/0181d5ca54afaee7b267501138dcd8db41deb622f8d8d3c9bad32ab33326/tket2-0.5.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:677297a5fecd73734342d5f1b17f178313817622843f5390dc7e662f64bac280", size = 4105606 }, + { url = "https://files.pythonhosted.org/packages/1b/40/5af365084dbf9f5443ea953221cb67465c7a9d43cf374f6b91bb9a35f262/tket2-0.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1559bd173e1b565f6da28288ea10a7d4ffd0daef82c5f4637263541b35c9d68f", size = 4406467 }, + { url = "https://files.pythonhosted.org/packages/3b/48/6a8f685882221093fffcd673957343ef0ebc703b4a33f8ad64fd30684c3f/tket2-0.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1173ecaf47200e79920bae4b7e4a1398e0a58c4ac484f2297729ae8d5ca9ec8a", size = 4476118 }, + { url = "https://files.pythonhosted.org/packages/16/2b/a1a3e506ab5394e564a2af07fe113c521f6511baa07a14e322c3df2f46a4/tket2-0.5.0-cp311-none-win32.whl", hash = "sha256:0e759d4c416a49d22f335706855d5d53b506a38c37981f5e4f9baa65c40c4ec7", size = 3745064 }, + { url = "https://files.pythonhosted.org/packages/eb/ee/7747f464389c30c6a6402d432d23bd6763b892cdc49660b28fb6d65362ed/tket2-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:d4df9b9acaff533e755dc9b0b052f7b02c7853a5021403c45b505933d3f87ffd", size = 4195959 }, + { url = "https://files.pythonhosted.org/packages/38/67/2bfdbaeab2e19ff6cf31a7d8ee9a7e78024972ee6c1ba205603555b68046/tket2-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:eb1eefb201b4fb5c589cb69a0bfd0ad43167ebf0886701db25f21979cd07cd3a", size = 4067909 }, + { url = "https://files.pythonhosted.org/packages/69/fc/fed0208573aa6f6dae6fa855405f50a20831096f55cee890d288a3b59529/tket2-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6008ea83f850c8827ae99435a290540e3ff216b81ca6f00e4c35988597a1cd6e", size = 3804559 }, + { url = "https://files.pythonhosted.org/packages/78/54/1d47c0eb57e04bad3467ff2d541c55bb127493bda1c42eaf28225a8533dd/tket2-0.5.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:99162729722b8b4733d8f14d11a208890e0022d12bd69624d36ce67484c35153", size = 4470733 }, + { url = "https://files.pythonhosted.org/packages/58/33/4da7a02cf9232354fe2216a5bcfdd45db382b30554ee3dcd1ff212bb7dca/tket2-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d20577b4f92a83d1933c59c79b837fb87372d2fcccf46936e8426f83103f36e9", size = 3936658 }, + { url = "https://files.pythonhosted.org/packages/26/15/ee480158780871c87e8c91dfb29063972b7895ea1bade1a284d8fba52cfc/tket2-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bcd4a249cb2bb9ebc11d23246a7053518f596d9ad3e7f82e046c4092f4301434", size = 3875720 }, + { url = "https://files.pythonhosted.org/packages/00/04/a5f7ce686dde146f7dcf8ef9c4fcccf8a86e1072301466b697ddd7148b98/tket2-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9729a028568287b2b66539ce65f43db07107dfa59622cdf74410b7ec270f626", size = 4515140 }, + { url = "https://files.pythonhosted.org/packages/a9/9c/9f54034f18f6bf127b057f36d51e61e92d5b9b7b9397a3f0b68cffec5044/tket2-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:075f3c90d8b747b9b24845fad9f6d84ae2428d80ad0d9c4e6d8eda7f207fc789", size = 5438753 }, + { url = "https://files.pythonhosted.org/packages/0d/06/650d41c6260fd71f73b11d05427821820b9c4b8633bbeb19093bdaf9b99d/tket2-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b47e9c6dec87b9360027d3f9127770f68483fdbe2202ddc145a50369b8af5428", size = 4340243 }, + { url = "https://files.pythonhosted.org/packages/39/3f/3b823959f0d9629f05d23fc57158473589e9ad6dbec7ec006d68662c27b4/tket2-0.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:81e9d8fb1a142f2375c1ec4fb38b7fbbf652efdbd2056496ed89196b682e5caf", size = 4123197 }, + { url = "https://files.pythonhosted.org/packages/e1/dc/8fa6baa5d64c9f0a41bf47dafd609c59caf5bb9caa041f6d278fb623f9aa/tket2-0.5.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:773dae0435c80e21b8e9e492e378f1f188be0e33d1ae0c616d6f7d0f50507029", size = 4098065 }, + { url = "https://files.pythonhosted.org/packages/d5/1c/a31d8a887670814a68047d536328573d234ed99e8356339c5692a53b7767/tket2-0.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4bcf1072dc55bd6d167c4044290984f8dc1785827b6a4ecfa880dd675e5fccb5", size = 4403253 }, + { url = "https://files.pythonhosted.org/packages/fd/8c/9ef0714d98a9649488bb91ec3f1c66ac6b526f725cccccf552522dc945fa/tket2-0.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c55fa3cf3252e9d3f87782650cf586409eb1c1ce095b361621389cf8537fb74f", size = 4476375 }, + { url = "https://files.pythonhosted.org/packages/2a/d6/b850caee0d9b7920750c5b9636f36ecd3562916414f6a1a852db5ea80fe6/tket2-0.5.0-cp312-none-win32.whl", hash = "sha256:5de9ef9b56ff47070581886472161a926ee4e42394d82d5930110e83733ff61d", size = 3749065 }, + { url = "https://files.pythonhosted.org/packages/f8/ac/0968af4b6847ebd03a94595e8846ad982c54d8fe7c9dd5233930e6b8676e/tket2-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:e96f1e1d1d9fa4c11e0c340ceae3057799419905d376f8eeb61f16055d2161b3", size = 4205451 }, + { url = "https://files.pythonhosted.org/packages/91/f4/2a1073033573a654a78c2dcbd6c93f794000063fd9355fae956a95d6b817/tket2-0.5.0-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:94b11bc7dee0e0d706b917a9aca9d300b2a861d8878d45be3136a2218cc88a2f", size = 4470143 }, + { url = "https://files.pythonhosted.org/packages/18/37/05f44f8912ddf76033bef52c08cbd4f9258746b06be50fe07aa6aa21792d/tket2-0.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28cd921b7a4ea190d2b012b5938a75a7e4f52633af8f94667f8d58b1d96c99dc", size = 3941477 }, + { url = "https://files.pythonhosted.org/packages/5d/3e/e94ae654b9559ed28ad24c0fc2c25891c0b9b7fe03d8fc87cd1d461db0d7/tket2-0.5.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d69b9c8fb1f2cc2164a8f21cd58baa8210a7ae2e577a3705838d80a5f51c1d91", size = 3883235 }, + { url = "https://files.pythonhosted.org/packages/09/1a/aceb016b53fc1bffe900704e9e298229e67e7d4ee2a03769aaecba2123d6/tket2-0.5.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a89601fb6f3cd224b1c545b6ce5604431caf5e0b86235f070fd835e2bc32e4fa", size = 4521748 }, + { url = "https://files.pythonhosted.org/packages/66/72/7b91d44aeae94516ec59e3a09b4ab421f2e6a61822c09d1129f47dcd4dd6/tket2-0.5.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:232c06a95f45774aeaccaf736a7b51f3ce4b27eaa575e56b44c76012e62d821d", size = 5457635 }, + { url = "https://files.pythonhosted.org/packages/31/42/ec01391ce07794002160dd928b45c2330fbdb15fd7cd97934b652fb24cd3/tket2-0.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f06d9f0a191c010bd8b6ad826bc1aa10f8b8566f25cc971407482100799a4a48", size = 4337550 }, + { url = "https://files.pythonhosted.org/packages/7e/78/4f1169799bae4fda6dca90b1a44f15f5d5962d7d14353fab754c84e351db/tket2-0.5.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:8b0cd2afd8b088f37ee4e26b45d9f923526c91446755523bab9381266f616892", size = 4123793 }, + { url = "https://files.pythonhosted.org/packages/f7/73/f311aba98b5dfc619ac60f57fa30e1642d091b2dcd5ab98b0bad6b277f90/tket2-0.5.0-pp310-pypy310_pp73-musllinux_1_2_armv7l.whl", hash = "sha256:d0501c984ee7f6adf73b56ef0c7d1d7e2db6e3a9c1d44f3e9e66f50af652d57e", size = 4101292 }, + { url = "https://files.pythonhosted.org/packages/21/8f/dbae07d8beeedfa1e1788b921c61009734209452809f676c1934e26c0177/tket2-0.5.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:df97424b2e64a804887d17e0e0f3be849fbd88cbf6a13563dc00842a478d6bbf", size = 4413372 }, + { url = "https://files.pythonhosted.org/packages/4c/2a/dc3a947a7195c715bc161ad303afc0f4b836718389161ad4783573ff3748/tket2-0.5.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:a4d8d23ccb9f28d2c8640a274521747f71d3ea839b64564820e4ffe317b31e42", size = 4474374 }, +] + +[[package]] +name = "tket2-eccs" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/75/6a19814c04f4b69bb464b1b4ac80fb317f4ed77bf1474f5474ac9e0c9d4a/tket2_eccs-0.2.0.tar.gz", hash = "sha256:201e52f2eaa6eb54df814f331ad597d3733e12d8fd6c4f8ad571572460c2f62b", size = 4428924 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/1f/9dbd087f3fc1195dd4c9c5a320923a504dddbdbd4ae3efada551cb291dfe/tket2_eccs-0.2.0-py3-none-any.whl", hash = "sha256:b280f7112fb743383ecd6077c8aa385a3f7f909b7618c345bbebe3c56ca3eb7f", size = 4431360 }, +] + [[package]] name = "tket2-exts" version = "0.2.0" From a56ab6b44fe42d419a7619d57d24f668699aa819 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Wed, 4 Dec 2024 17:17:01 +0000 Subject: [PATCH 19/21] Remove test skips --- tests/integration/test_pytket_circuits.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index e117af7e..dc3e1aa0 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -33,7 +33,6 @@ def foo(q: qubit) -> None: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("remove this") def test_multi_qubit_circuit(validate): from pytket import Circuit @@ -76,7 +75,6 @@ def foo(q: qubit) -> bool: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("remove this") def test_measure_multiple(validate): from pytket import Circuit @@ -98,7 +96,6 @@ def foo(q1: qubit, q2: qubit) -> tuple[bool, bool]: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("remove this") def test_measure_not_last(validate): from pytket import Circuit From 99ee62180fb4db3e229a1d258df6b594dc96e8b1 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Fri, 6 Dec 2024 09:42:37 +0000 Subject: [PATCH 20/21] Address comments --- guppylang/checker/errors/py_errors.py | 4 +++- guppylang/definition/pytket_circuits.py | 21 ++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/guppylang/checker/errors/py_errors.py b/guppylang/checker/errors/py_errors.py index 1c5a984e..2ba134f5 100644 --- a/guppylang/checker/errors/py_errors.py +++ b/guppylang/checker/errors/py_errors.py @@ -58,5 +58,7 @@ class InstallInstruction(Help): @dataclass(frozen=True) class PytketSignatureMismatch(Error): title: ClassVar[str] = "Signature mismatch" - span_label: ClassVar[str] = "Signature {name} doesn't match provided pytket circuit" + span_label: ClassVar[str] = ( + "Signature `{name}` doesn't match provided pytket circuit" + ) name: str diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index df5a377f..caaf5b02 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -3,7 +3,7 @@ from typing import Any, cast import hugr.build.function as hf -from hugr import Hugr, OutPort, Wire, ops, tys, val +from hugr import Hugr, Wire, val from hugr.build.dfg import DefinitionBuilder, OpVar from guppylang.ast_util import AstNode, has_empty_body, with_loc @@ -158,21 +158,20 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" # Initialise every input bit in the circuit as false. # TODO: Provide the option for the user to pass this input as well. - bool_wires: list[OutPort] = [] - for child in module.hugr.children(hugr_func): - node_data = module.hugr.get(child) - if node_data and isinstance(node_data.op, ops.Input): - input_types = node_data.op.types - num_bools = input_types.count(tys.Bool) - for _ in range(num_bools): - bool_node = outer_func.load(val.FALSE) - bool_wires.append(*bool_node.outputs()) + bool_wires = [ + outer_func.load(val.FALSE) for _ in range(self.input_circuit.n_bits) + ] call_node = outer_func.call( hugr_func, *(list(outer_func.inputs()) + bool_wires) ) # Pytket circuit hugr has qubit and bool wires in the opposite order. - outer_func.set_outputs(*reversed(list(call_node.outputs()))) + output_list = list(call_node.outputs()) + wires = ( + output_list[self.input_circuit.n_qubits :] + + output_list[: self.input_circuit.n_qubits] + ) + outer_func.set_outputs(*wires) except ImportError: pass From 0864123d1ff2cf2322f24d8b49c5a8f61eb104d5 Mon Sep 17 00:00:00 2001 From: Tatiana S Date: Fri, 6 Dec 2024 09:51:22 +0000 Subject: [PATCH 21/21] Fix error test --- tests/error/py_errors/sig_mismatch.err | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/error/py_errors/sig_mismatch.err b/tests/error/py_errors/sig_mismatch.err index ba5c8d31..47f22396 100644 --- a/tests/error/py_errors/sig_mismatch.err +++ b/tests/error/py_errors/sig_mismatch.err @@ -3,6 +3,7 @@ Error: Signature mismatch (at $FILE:15:0) 13 | 14 | @guppy.pytket(circ, module) 15 | def guppy_circ(q: qubit) -> None: ... - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Signature guppy_circ doesn't match provided pytket circuit + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Signature `guppy_circ` doesn't match provided pytket + | circuit Guppy compilation failed due to 1 previous error