Skip to content

Commit

Permalink
test: move shared code in to conftest
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 27, 2024
1 parent 4c3390e commit 4329245
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 143 deletions.
122 changes: 122 additions & 0 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from dataclasses import dataclass, field
import subprocess
import os
import pathlib
from hugr.node_port import Wire

from hugr.hugr import Hugr
from hugr.ops import Custom, Command
from hugr.serialization import SerialHugr
import hugr.tys as tys
import hugr.val as val
import json


def int_t(width: int) -> tys.Opaque:
return tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.BoundedNatArg(n=width)],
bound=tys.TypeBound.Eq,
)


INT_T = int_t(5)


@dataclass
class IntVal(val.ExtensionValue):
v: int

def to_value(self) -> val.Extension:
return val.Extension("int", INT_T, self.v)


@dataclass
class LogicOps(Custom):
extension: tys.ExtensionId = "logic"


# TODO get from YAML
@dataclass
class NotDef(LogicOps):
num_out: int | None = 1
op_name: str = "Not"
signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool])

def __call__(self, a: Wire) -> Command:
return super().__call__(a)


Not = NotDef()


@dataclass
class QuantumOps(Custom):
extension: tys.ExtensionId = "tket2.quantum"


@dataclass
class OneQbGate(QuantumOps):
op_name: str
num_out: int | None = 1
signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit])

def __call__(self, q: Wire) -> Command:
return super().__call__(q)


H = OneQbGate("H")


@dataclass
class MeasureDef(QuantumOps):
op_name: str = "Measure"
num_out: int | None = 2
signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])

def __call__(self, q: Wire) -> Command:
return super().__call__(q)


Measure = MeasureDef()


@dataclass
class IntOps(Custom):
extension: tys.ExtensionId = "arithmetic.int"


ARG_5 = tys.BoundedNatArg(n=5)


@dataclass
class DivModDef(IntOps):
num_out: int | None = 2
extension: tys.ExtensionId = "arithmetic.int"
op_name: str = "idivmod_u"
signature: tys.FunctionType = field(
default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2)
)
args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5])


DivMod = DivModDef()


def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True):
workspace_dir = pathlib.Path(__file__).parent.parent.parent
# use the HUGR_BIN environment variable if set, otherwise use the debug build
bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr"))
cmd = [bin_loc, "-"]

if mermaid:
cmd.append("--mermaid")
serial = h.to_serial().to_json()
subprocess.run(cmd, check=True, input=serial.encode())

if roundtrip:
h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial)))
assert serial == h2.to_serial().to_json()
12 changes: 6 additions & 6 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hugr.val as val
from hugr.dfg import Dfg
import hugr.ops as ops
from .test_hugr_build import _validate, INT_T, DivMod, IntVal
from .conftest import validate, INT_T, DivMod, IntVal


def build_basic_cfg(cfg: Cfg) -> None:
Expand All @@ -16,7 +16,7 @@ def build_basic_cfg(cfg: Cfg) -> None:
def test_basic_cfg() -> None:
cfg = Cfg([tys.Bool])
build_basic_cfg(cfg)
_validate(cfg.hugr)
validate(cfg.hugr)


def test_branch() -> None:
Expand All @@ -34,7 +34,7 @@ def test_branch() -> None:
cfg.branch_exit(middle_1[0])
cfg.branch_exit(middle_2[0])

_validate(cfg.hugr)
validate(cfg.hugr)


def test_nested_cfg() -> None:
Expand All @@ -45,7 +45,7 @@ def test_nested_cfg() -> None:
build_basic_cfg(cfg)
dfg.set_outputs(cfg)

_validate(dfg.hugr)
validate(dfg.hugr)


def test_dom_edge() -> None:
Expand All @@ -64,7 +64,7 @@ def test_dom_edge() -> None:
cfg.branch_exit(middle_1[0])
cfg.branch_exit(middle_2[0])

_validate(cfg.hugr)
validate(cfg.hugr)


def test_asymm_types() -> None:
Expand All @@ -86,7 +86,7 @@ def test_asymm_types() -> None:
cfg.branch_exit(entry[1])
cfg.branch_exit(middle[0])

_validate(cfg.hugr)
validate(cfg.hugr)


# TODO loop
16 changes: 8 additions & 8 deletions hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import hugr.ops as ops
import hugr.val as val
import pytest
from .test_hugr_build import INT_T, _validate, IntVal, H, Measure
from .conftest import INT_T, validate, IntVal, H, Measure

SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]])

Expand All @@ -25,7 +25,7 @@ def build_cond(h: Conditional) -> None:
def test_cond() -> None:
h = Conditional(SUM_T, [tys.Bool])
build_cond(h)
_validate(h.hugr)
validate(h.hugr)


def test_nested_cond() -> None:
Expand All @@ -35,7 +35,7 @@ def test_nested_cond() -> None:
cond = h.add_conditional(tagged_q, h.load(val.TRUE))
build_cond(cond)
h.set_outputs(*cond[:2])
_validate(h.hugr)
validate(h.hugr)

# build then insert
con = Conditional(SUM_T, [tys.Bool])
Expand All @@ -46,7 +46,7 @@ def test_nested_cond() -> None:
tagged_q = h.add(ops.Tag(0, SUM_T)(q))
cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE))
h.set_outputs(*cond_n[:2])
_validate(h.hugr)
validate(h.hugr)


def test_if_else() -> None:
Expand All @@ -63,7 +63,7 @@ def test_if_else() -> None:
cond = else_.finish()
h.set_outputs(cond)

_validate(h.hugr)
validate(h.hugr)


def test_tail_loop() -> None:
Expand All @@ -80,7 +80,7 @@ def build_tl(tl: TailLoop) -> None:
build_tl(tl)
h.set_outputs(tl)

_validate(h.hugr)
validate(h.hugr)

# build then insert
tl = TailLoop([], [tys.Qubit])
Expand All @@ -90,7 +90,7 @@ def build_tl(tl: TailLoop) -> None:
(q,) = h.inputs()
tl_n = h.insert_tail_loop(tl, q)
h.set_outputs(tl_n)
_validate(h.hugr)
validate(h.hugr)


def test_complex_tail_loop() -> None:
Expand Down Expand Up @@ -119,6 +119,6 @@ def test_complex_tail_loop() -> None:
# loop returns [qubit, int, bool]
h.set_outputs(*tl[:3])

_validate(h.hugr, True)
validate(h.hugr, True)

# TODO rewrite with context managers
Loading

0 comments on commit 4329245

Please sign in to comment.