Skip to content

Commit

Permalink
feat: add wrapper for tagging Some, Left, Right, Break, Continue (#1814)
Browse files Browse the repository at this point in the history
Closes #1808
  • Loading branch information
ss2165 authored Dec 18, 2024
1 parent ab94518 commit f0385a0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
55 changes: 55 additions & 0 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,61 @@ def __repr__(self) -> str:
return f"Tag({self.tag})"


@dataclass
class Some(Tag):
"""Tag operation for the `Some` variant of an Option type.
Example:
# construct a Some variant holding a row of Bool and Unit types
>>> Some(tys.Bool, tys.Unit)
Some
"""

def __init__(self, *some_tys: tys.Type) -> None:
super().__init__(1, tys.Option(*some_tys))

def __repr__(self) -> str:
return "Some"


@dataclass
class Right(Tag):
"""Tag operation for the `Right` variant of an type."""

def __init__(self, either_type: tys.Either) -> None:
super().__init__(1, either_type)

def __repr__(self) -> str:
return "Right"


@dataclass
class Left(Tag):
"""Tag operation for the `Left` variant of an type."""

def __init__(self, either_type: tys.Either) -> None:
super().__init__(0, either_type)

def __repr__(self) -> str:
return "Left"


class Continue(Left):
"""Tag operation for the `Continue` variant of a TailLoop
controlling Either type.
"""

def __repr__(self) -> str:
return "Continue"


class Break(Right):
"""Tag operation for the `Break` variant of a TailLoop controlling Either type."""

def __repr__(self) -> str:
return "Break"


class DfParentOp(Op, Protocol):
"""Abstract parent of dataflow graph operations. Can be queried for the
dataflow signature of its child graph.
Expand Down
16 changes: 8 additions & 8 deletions hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .conftest import QUANTUM_EXT, H, Measure, validate

SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]])
EITHER_T = tys.Either([tys.Qubit], [tys.Qubit, INT_T])


def build_cond(h: Conditional) -> None:
Expand All @@ -25,15 +25,15 @@ def build_cond(h: Conditional) -> None:


def test_cond() -> None:
h = Conditional(SUM_T, [tys.Bool])
h = Conditional(EITHER_T, [tys.Bool])
build_cond(h)
validate(h.hugr)


def test_nested_cond() -> None:
h = Dfg(tys.Qubit)
(q,) = h.inputs()
tagged_q = h.add(ops.Tag(0, SUM_T)(q))
tagged_q = h.add(ops.Left(EITHER_T)(q))

with h.add_conditional(tagged_q, h.load(val.TRUE)) as cond:
build_cond(cond)
Expand All @@ -42,12 +42,12 @@ def test_nested_cond() -> None:
validate(h.hugr)

# build then insert
con = Conditional(SUM_T, [tys.Bool])
con = Conditional(EITHER_T, [tys.Bool])
build_cond(con)

h = Dfg(tys.Qubit)
(q,) = h.inputs()
tagged_q = h.add(ops.Tag(0, SUM_T)(q))
tagged_q = h.add(ops.Left(EITHER_T)(q))
cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE))
h.set_outputs(*cond_n[:2])
validate(h.hugr)
Expand All @@ -70,7 +70,7 @@ def test_if_else() -> None:

def test_incomplete() -> None:
def _build_incomplete():
with Conditional(SUM_T, [tys.Bool]) as c, c.add_case(0) as case0:
with Conditional(EITHER_T, [tys.Bool]) as c, c.add_case(0) as case0:
q, b = case0.inputs()
case0.set_outputs(q, b)

Expand Down Expand Up @@ -118,13 +118,13 @@ def test_complex_tail_loop() -> None:
# if b is true, return first variant (just qubit)
with tl.add_if(b, q) as if_:
(q,) = if_.inputs()
tagged_q = if_.add(ops.Tag(0, SUM_T)(q))
tagged_q = if_.add(ops.Continue(EITHER_T)(q))
if_.set_outputs(tagged_q)

# else return second variant (qubit, int)
with if_.add_else() as else_:
(q,) = else_.inputs()
tagged_q_i = else_.add(ops.Tag(1, SUM_T)(q, else_.load(IntVal(1))))
tagged_q_i = else_.add(ops.Break(EITHER_T)(q, else_.load(IntVal(1))))
else_.set_outputs(tagged_q_i)

# finish with Sum output from if-else, and bool from inputs
Expand Down
11 changes: 11 additions & 0 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,14 @@ def test_dfg_unpack() -> None:
dfg.set_outputs(*cond.outputs())

validate(dfg.hugr)


def test_option() -> None:
dfg = Dfg(tys.Bool)
b = dfg.inputs()[0]

dfg.add_op(ops.Some(tys.Bool), b)

dfg.set_outputs(b)

validate(dfg.hugr)

0 comments on commit f0385a0

Please sign in to comment.