diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py new file mode 100644 index 000000000..33f792d0b --- /dev/null +++ b/hugr-py/tests/conftest.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import subprocess +import os +import pathlib +from hugr.node_port import Wire + +from hugr.hugr import Hugr +from hugr.ops import Custom, Command +from hugr.serialization import SerialHugr +import hugr.tys as tys +import hugr.val as val +import json + + +def int_t(width: int) -> tys.Opaque: + return tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.BoundedNatArg(n=width)], + bound=tys.TypeBound.Eq, + ) + + +INT_T = int_t(5) + + +@dataclass +class IntVal(val.ExtensionValue): + v: int + + def to_value(self) -> val.Extension: + return val.Extension("int", INT_T, self.v) + + +@dataclass +class LogicOps(Custom): + extension: tys.ExtensionId = "logic" + + +# TODO get from YAML +@dataclass +class NotDef(LogicOps): + num_out: int | None = 1 + op_name: str = "Not" + signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) + + def __call__(self, a: Wire) -> Command: + return super().__call__(a) + + +Not = NotDef() + + +@dataclass +class QuantumOps(Custom): + extension: tys.ExtensionId = "tket2.quantum" + + +@dataclass +class OneQbGate(QuantumOps): + op_name: str + num_out: int | None = 1 + signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +H = OneQbGate("H") + + +@dataclass +class MeasureDef(QuantumOps): + op_name: str = "Measure" + num_out: int | None = 2 + signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +Measure = MeasureDef() + + +@dataclass +class IntOps(Custom): + extension: tys.ExtensionId = "arithmetic.int" + + +ARG_5 = tys.BoundedNatArg(n=5) + + +@dataclass +class DivModDef(IntOps): + num_out: int | None = 2 + extension: tys.ExtensionId = "arithmetic.int" + op_name: str = "idivmod_u" + signature: tys.FunctionType = field( + default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + ) + args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) + + +DivMod = DivModDef() + + +def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): + workspace_dir = pathlib.Path(__file__).parent.parent.parent + # use the HUGR_BIN environment variable if set, otherwise use the debug build + bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr")) + cmd = [bin_loc, "-"] + + if mermaid: + cmd.append("--mermaid") + serial = h.to_serial().to_json() + subprocess.run(cmd, check=True, input=serial.encode()) + + if roundtrip: + h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) + assert serial == h2.to_serial().to_json() diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 3fbfbd1a0..ddd3b9c32 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -3,7 +3,7 @@ import hugr.val as val from hugr.dfg import Dfg import hugr.ops as ops -from .test_hugr_build import _validate, INT_T, DivMod, IntVal +from .conftest import validate, INT_T, DivMod, IntVal def build_basic_cfg(cfg: Cfg) -> None: @@ -16,7 +16,7 @@ def build_basic_cfg(cfg: Cfg) -> None: def test_basic_cfg() -> None: cfg = Cfg([tys.Bool]) build_basic_cfg(cfg) - _validate(cfg.hugr) + validate(cfg.hugr) def test_branch() -> None: @@ -34,7 +34,7 @@ def test_branch() -> None: cfg.branch_exit(middle_1[0]) cfg.branch_exit(middle_2[0]) - _validate(cfg.hugr) + validate(cfg.hugr) def test_nested_cfg() -> None: @@ -45,7 +45,7 @@ def test_nested_cfg() -> None: build_basic_cfg(cfg) dfg.set_outputs(cfg) - _validate(dfg.hugr) + validate(dfg.hugr) def test_dom_edge() -> None: @@ -64,7 +64,7 @@ def test_dom_edge() -> None: cfg.branch_exit(middle_1[0]) cfg.branch_exit(middle_2[0]) - _validate(cfg.hugr) + validate(cfg.hugr) def test_asymm_types() -> None: @@ -86,7 +86,7 @@ def test_asymm_types() -> None: cfg.branch_exit(entry[1]) cfg.branch_exit(middle[0]) - _validate(cfg.hugr) + validate(cfg.hugr) # TODO loop diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 4e617786f..75791317b 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -4,7 +4,7 @@ import hugr.ops as ops import hugr.val as val import pytest -from .test_hugr_build import INT_T, _validate, IntVal, H, Measure +from .conftest import INT_T, validate, IntVal, H, Measure SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) @@ -25,7 +25,7 @@ def build_cond(h: Conditional) -> None: def test_cond() -> None: h = Conditional(SUM_T, [tys.Bool]) build_cond(h) - _validate(h.hugr) + validate(h.hugr) def test_nested_cond() -> None: @@ -35,7 +35,7 @@ def test_nested_cond() -> None: cond = h.add_conditional(tagged_q, h.load(val.TRUE)) build_cond(cond) h.set_outputs(*cond[:2]) - _validate(h.hugr) + validate(h.hugr) # build then insert con = Conditional(SUM_T, [tys.Bool]) @@ -46,7 +46,7 @@ def test_nested_cond() -> None: tagged_q = h.add(ops.Tag(0, SUM_T)(q)) cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE)) h.set_outputs(*cond_n[:2]) - _validate(h.hugr) + validate(h.hugr) def test_if_else() -> None: @@ -63,7 +63,7 @@ def test_if_else() -> None: cond = else_.finish() h.set_outputs(cond) - _validate(h.hugr) + validate(h.hugr) def test_tail_loop() -> None: @@ -80,7 +80,7 @@ def build_tl(tl: TailLoop) -> None: build_tl(tl) h.set_outputs(tl) - _validate(h.hugr) + validate(h.hugr) # build then insert tl = TailLoop([], [tys.Qubit]) @@ -90,7 +90,7 @@ def build_tl(tl: TailLoop) -> None: (q,) = h.inputs() tl_n = h.insert_tail_loop(tl, q) h.set_outputs(tl_n) - _validate(h.hugr) + validate(h.hugr) def test_complex_tail_loop() -> None: @@ -119,6 +119,6 @@ def test_complex_tail_loop() -> None: # loop returns [qubit, int, bool] h.set_outputs(*tl[:3]) - _validate(h.hugr, True) + validate(h.hugr, True) # TODO rewrite with context managers diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 37a930a14..a01b89d97 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -1,128 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass, field -import subprocess -import os -import pathlib -from hugr.node_port import Node, Wire, _SubPort +from hugr.node_port import Node, _SubPort from hugr.hugr import Hugr from hugr.dfg import Dfg, _ancestral_sibling -from hugr.ops import Custom, Command, NoConcreteFunc +from hugr.ops import NoConcreteFunc import hugr.ops as ops -from hugr.serialization import SerialHugr import hugr.tys as tys import hugr.val as val from hugr.function import Module import pytest -import json - -def int_t(width: int) -> tys.Opaque: - return tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=width)], - bound=tys.TypeBound.Eq, - ) - - -INT_T = int_t(5) - - -@dataclass -class IntVal(val.ExtensionValue): - v: int - - def to_value(self) -> val.Extension: - return val.Extension("int", INT_T, self.v) - - -@dataclass -class LogicOps(Custom): - extension: tys.ExtensionId = "logic" - - -# TODO get from YAML -@dataclass -class NotDef(LogicOps): - num_out: int | None = 1 - op_name: str = "Not" - signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) - - def __call__(self, a: Wire) -> Command: - return super().__call__(a) - - -Not = NotDef() - - -@dataclass -class QuantumOps(Custom): - extension: tys.ExtensionId = "tket2.quantum" - - -@dataclass -class OneQbGate(QuantumOps): - op_name: str - num_out: int | None = 1 - signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) - - def __call__(self, q: Wire) -> Command: - return super().__call__(q) - - -H = OneQbGate("H") - - -@dataclass -class MeasureDef(QuantumOps): - op_name: str = "Measure" - num_out: int | None = 2 - signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) - - def __call__(self, q: Wire) -> Command: - return super().__call__(q) - - -Measure = MeasureDef() - - -@dataclass -class IntOps(Custom): - extension: tys.ExtensionId = "arithmetic.int" - - -ARG_5 = tys.BoundedNatArg(n=5) - - -@dataclass -class DivModDef(IntOps): - num_out: int | None = 2 - extension: tys.ExtensionId = "arithmetic.int" - op_name: str = "idivmod_u" - signature: tys.FunctionType = field( - default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) - ) - args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) - - -DivMod = DivModDef() - - -def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): - workspace_dir = pathlib.Path(__file__).parent.parent.parent - # use the HUGR_BIN environment variable if set, otherwise use the debug build - bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr")) - cmd = [bin_loc, "-"] - - if mermaid: - cmd.append("--mermaid") - serial = h.to_serial().to_json() - subprocess.run(cmd, check=True, input=serial.encode()) - - if roundtrip: - h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) - assert serial == h2.to_serial().to_json() +from .conftest import Not, INT_T, IntVal, validate, DivMod def test_stable_indices(): @@ -168,7 +56,7 @@ def simple_id() -> Dfg: def test_simple_id(): - _validate(simple_id().hugr) + validate(simple_id().hugr) def test_multiport(): @@ -191,7 +79,7 @@ def test_multiport(): ] assert list(h.hugr.linked_ports(ou_n.inp(0))) == [in_n.out(0)] - _validate(h.hugr) + validate(h.hugr) def test_add_op(): @@ -200,7 +88,7 @@ def test_add_op(): nt = h.add_op(Not, a) h.set_outputs(nt) - _validate(h.hugr) + validate(h.hugr) def test_tuple(): @@ -211,7 +99,7 @@ def test_tuple(): a, b = h.add(ops.UnpackTuple(t)) h.set_outputs(a, b) - _validate(h.hugr) + validate(h.hugr) h1 = Dfg(*row) a, b = h1.inputs() @@ -227,7 +115,7 @@ def test_multi_out(): a, b = h.inputs() a, b = h.add(DivMod(a, b)) h.set_outputs(a, b) - _validate(h.hugr) + validate(h.hugr) def test_insert(): @@ -254,7 +142,7 @@ def test_insert_nested(): nested = h.insert_nested(h1, a) h.set_outputs(nested) assert len(h.hugr.children(nested)) == 3 - _validate(h.hugr) + validate(h.hugr) def test_build_nested(): @@ -271,7 +159,7 @@ def _nested_nop(dfg: Dfg): assert len(h.hugr.children(nested)) == 3 h.set_outputs(nested) - _validate(h.hugr) + validate(h.hugr) def test_build_inter_graph(): @@ -284,7 +172,7 @@ def test_build_inter_graph(): h.set_outputs(nested, b) - _validate(h.hugr) + validate(h.hugr) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order @@ -315,7 +203,7 @@ def test_vals(val: val.Value): d = Dfg() d.set_outputs(d.load(val)) - _validate(d.hugr) + validate(d.hugr) @pytest.mark.parametrize("direct_call", [True, False]) @@ -348,7 +236,7 @@ def test_poly_function(direct_call: bool) -> None: f_main.set_outputs(call) - _validate(mod.hugr, True) + validate(mod.hugr, True) @pytest.mark.parametrize("direct_call", [True, False]) @@ -367,7 +255,7 @@ def test_mono_function(direct_call: bool) -> None: call = f_main.add(ops.CallIndirect(load, q)) f_main.set_outputs(call) - _validate(mod.hugr) + validate(mod.hugr) def test_higher_order() -> None: @@ -380,7 +268,7 @@ def test_higher_order() -> None: call = d.add(ops.CallIndirect(f_val, q))[0] d.set_outputs(call) - _validate(d.hugr) + validate(d.hugr) def test_lift() -> None: @@ -388,7 +276,7 @@ def test_lift() -> None: (q,) = d.inputs() lift = d.add(ops.Lift("X")(q)) d.set_outputs(lift) - _validate(d.hugr) + validate(d.hugr) def test_alias() -> None: @@ -396,4 +284,4 @@ def test_alias() -> None: _dfn = mod.add_alias_defn("my_int", INT_T) _dcl = mod.add_alias_decl("my_bool", tys.TypeBound.Eq) - _validate(mod.hugr) + validate(mod.hugr)