diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index e905917d7..a01ed5237 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -131,7 +131,15 @@ def add(self, com: ops.Command) -> Node: Node(3) """ - return self.add_op(com.op, *com.incoming) + + def raise_no_ints(): + error_message = "Command used with Dfg must hold Wire, not integer indices." + raise ValueError(error_message) + + wires = ( + (w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming + ) + return self.add_op(com.op, *wires) def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 2d6432ff5..5719ab092 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -133,6 +133,9 @@ def _check_complete(op, v: V | None) -> V: return v +ComWire = Wire | int + + @dataclass(frozen=True) class Command: """A :class:`DataflowOp` and its incoming :class:`Wire ` @@ -146,7 +149,7 @@ class Command: """ op: DataflowOp - incoming: list[Wire] + incoming: list[ComWire] @dataclass() @@ -244,7 +247,7 @@ def to_serial(self, parent: Node) -> sops.MakeTuple: tys=ser_it(self.types), ) - def __call__(self, *elements: Wire) -> Command: + def __call__(self, *elements: ComWire) -> Command: return super().__call__(*elements) def outer_signature(self) -> tys.FunctionType: @@ -282,7 +285,7 @@ def to_serial(self, parent: Node) -> sops.UnpackTuple: tys=ser_it(self.types), ) - def __call__(self, tuple_: Wire) -> Command: + def __call__(self, tuple_: ComWire) -> Command: return super().__call__(tuple_) def outer_signature(self) -> tys.FunctionType: @@ -925,7 +928,7 @@ def to_serial(self, parent: Node) -> sops.CallIndirect: signature=self.signature.to_serial(), ) - def __call__(self, function: Wire, *args: Wire) -> Command: # type: ignore[override] + def __call__(self, function: ComWire, *args: ComWire) -> Command: # type: ignore[override] return super().__call__(function, *args) def outer_signature(self) -> tys.FunctionType: diff --git a/hugr-py/src/hugr/tracked_dfg.py b/hugr-py/src/hugr/tracked_dfg.py new file mode 100644 index 000000000..51ddc0d7b --- /dev/null +++ b/hugr-py/src/hugr/tracked_dfg.py @@ -0,0 +1,220 @@ +"""Dfg builder that allows tracking a set of wires and appending operations by index.""" + +from collections.abc import Iterable + +from hugr import tys +from hugr.dfg import Dfg +from hugr.node_port import Node, Wire +from hugr.ops import Command, ComWire + + +class TrackedDfg(Dfg): + """Dfg builder to append operations to wires by index. + + Args: + *input_types: Input types of the Dfg. + track_inputs: Whether to track the input wires. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True) + >>> dfg.tracked + [OutPort(Node(1), 0), OutPort(Node(1), 1)] + """ + + #: Tracked wires. None if index is no longer tracked. + tracked: list[Wire | None] + + def __init__(self, *input_types: tys.Type, track_inputs: bool = False) -> None: + super().__init__(*input_types) + self.tracked = list(self.inputs()) if track_inputs else [] + + def track_wire(self, wire: Wire) -> int: + """Add a wire from this DFG to the tracked wires, and return its index. + + Args: + wire: Wire to track. + + Returns: + Index of the tracked wire. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit) + >>> dfg.track_wire(dfg.inputs()[0]) + 0 + """ + self.tracked.append(wire) + return len(self.tracked) - 1 + + def untrack_wire(self, index: int) -> Wire: + """Untrack a wire by index and return it. + + Args: + index: Index of the wire to untrack. + + Returns: + Wire that was untracked. + + Raises: + IndexError: If the index is not a tracked wire. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit) + >>> w = dfg.inputs()[0] + >>> idx = dfg.track_wire(w) + >>> dfg.untrack_wire(idx) == w + True + """ + w = self.tracked_wire(index) + self.tracked[index] = None + return w + + def track_wires(self, wires: Iterable[Wire]) -> list[int]: + """Set a list of wires to be tracked and return their indices. + + Args: + wires: Wires to track. + + Returns: + List of indices of the tracked wires. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit) + >>> dfg.track_wires(dfg.inputs()) + [0, 1] + """ + return [self.track_wire(w) for w in wires] + + def track_inputs(self) -> list[int]: + """Track all input wires and return their indices. + + Returns: + List of indices of the tracked input wires. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit) + >>> dfg.track_inputs() + [0, 1] + """ + return self.track_wires(self.inputs()) + + def tracked_wire(self, index: int) -> Wire: + """Get the tracked wire at the given index. + + Args: + index: Index of the tracked wire. + + Raises: + IndexError: If the index is not a tracked wire. + + Returns: + Tracked wire + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True) + >>> dfg.tracked_wire(0) == dfg.inputs()[0] + True + """ + try: + tracked = self.tracked[index] + except IndexError: + tracked = None + if tracked is None: + msg = f"Index {index} not a tracked wire." + raise IndexError(msg) + return tracked + + def append(self, com: Command) -> Node: + """Add a command to the DFG. + + Any incoming :class:`Wire ` will + be connected directly, while any integer will be treated as a reference + to the tracked wire at that index. + + Any tracked wires will be updated to the output of the new node at the same port + as the incoming index. + + Args: + com: Command to append. + + Returns: + The new node. + + Raises: + IndexError: If any input index is not a tracked wire. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, track_inputs=True) + >>> dfg.tracked + [OutPort(Node(1), 0)] + >>> dfg.append(ops.Noop()(0)) + Node(3) + >>> dfg.tracked + [OutPort(Node(3), 0)] + """ + wires = self._to_wires(com.incoming) + n = self.add_op(com.op, *wires) + + for port_offset, com_wire in enumerate(com.incoming): + if isinstance(com_wire, int): + tracked_idx = com_wire + else: + continue + # update tracked wires to matching port outputs of new node + self.tracked[tracked_idx] = n.out(port_offset) + + return n + + def _to_wires(self, in_wires: Iterable[ComWire]) -> Iterable[Wire]: + return ( + self.tracked_wire(inc) if isinstance(inc, int) else inc for inc in in_wires + ) + + def extend(self, coms: Iterable[Command]) -> list[Node]: + """Add a series of commands to the DFG. + + Shorthand for calling :meth:`append` on each command in `coms`. + + Args: + coms: Commands to append. + + Returns: + List of the new nodes in the same order as the commands. + + Raises: + IndexError: If any input index is not a tracked wire. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True) + >>> dfg.extend([ops.Noop()(0), ops.Noop()(1)]) + [Node(3), Node(4)] + """ + return [self.append(com) for com in coms] + + def set_indexed_outputs(self, *in_wires: ComWire) -> None: + """Set the Dfg outputs, using either :class:`Wire ` or + indices to tracked wires. + + Args: + *in_wires: Wires/indices to set as outputs. + + Raises: + IndexError: If any input index is not a tracked wire. + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit) + >>> (b, i) = dfg.inputs() + >>> dfg.track_wire(b) + 0 + >>> dfg.set_indexed_outputs(0, i) + """ + self.set_outputs(*self._to_wires(in_wires)) + + def set_tracked_outputs(self) -> None: + """Set the Dfg outputs to the tracked wires. + + + Examples: + >>> dfg = TrackedDfg(tys.Bool, tys.Unit, track_inputs=True) + >>> dfg.set_tracked_outputs() + """ + self.set_outputs(*(w for w in self.tracked if w is not None)) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index f1eff88b7..d72eb46e1 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -13,7 +13,7 @@ from hugr.serialization.serial_hugr import SerialHugr if TYPE_CHECKING: - from hugr.node_port import Wire + from hugr.ops import ComWire def int_t(width: int) -> tys.Opaque: @@ -36,6 +36,22 @@ def to_value(self) -> val.Extension: return val.Extension("int", INT_T, self.v) +FLOAT_T = tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, +) + + +@dataclass +class FloatVal(val.ExtensionValue): + v: float + + def to_value(self) -> val.Extension: + return val.Extension("float", FLOAT_T, self.v) + + @dataclass class LogicOps(Custom): extension: tys.ExtensionId = "logic" @@ -51,7 +67,7 @@ class NotDef(LogicOps): op_name: str = "Not" signature: tys.FunctionType = _NotSig - def __call__(self, a: Wire) -> Command: + def __call__(self, a: ComWire) -> Command: return super().__call__(a) @@ -72,12 +88,28 @@ class OneQbGate(QuantumOps): num_out: int = 1 signature: tys.FunctionType = _OneQbSig - def __call__(self, q: Wire) -> Command: + def __call__(self, q: ComWire) -> Command: return super().__call__(q) H = OneQbGate("H") + +_TwoQbSig = tys.FunctionType.endo([tys.Qubit] * 2) + + +@dataclass +class TwoQbGate(QuantumOps): + op_name: str + num_out: int = 2 + signature: tys.FunctionType = _TwoQbSig + + def __call__(self, q0: ComWire, q1: ComWire) -> Command: + return super().__call__(q0, q1) + + +CX = TwoQbGate("CX") + _MeasSig = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) @@ -87,12 +119,27 @@ class MeasureDef(QuantumOps): num_out: int = 2 signature: tys.FunctionType = _MeasSig - def __call__(self, q: Wire) -> Command: + def __call__(self, q: ComWire) -> Command: return super().__call__(q) Measure = MeasureDef() +_RzSig = tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]) + + +@dataclass +class RzDef(QuantumOps): + op_name: str = "Rz" + num_out: int = 1 + signature: tys.FunctionType = _RzSig + + def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: + return super().__call__(q, fl_wire) + + +Rz = RzDef() + @dataclass class IntOps(Custom): diff --git a/hugr-py/tests/test_tracked_dfg.py b/hugr-py/tests/test_tracked_dfg.py new file mode 100644 index 000000000..c1dcbcc07 --- /dev/null +++ b/hugr-py/tests/test_tracked_dfg.py @@ -0,0 +1,68 @@ +import pytest + +from hugr import tys +from hugr.tracked_dfg import TrackedDfg + +from .conftest import CX, FLOAT_T, FloatVal, H, Measure, Not, Rz, validate + + +def test_track_wire(): + dfg = TrackedDfg(tys.Bool, tys.Unit) + inds = dfg.track_inputs() + assert inds == [0, 1] + assert dfg.tracked_wire(inds[0]) == dfg.inputs()[0] + with pytest.raises(IndexError, match="Index 2 not a tracked wire."): + dfg.tracked_wire(2) + w1 = dfg.tracked_wire(inds[1]) + w1_removed = dfg.untrack_wire(inds[1]) + assert w1 == w1_removed + with pytest.raises(IndexError, match="Index 1 not a tracked wire."): + dfg.tracked_wire(inds[1]) + + dfg.set_indexed_outputs(0) + + validate(dfg.hugr) + + +def simple_circuit(n_qb: int, float_in: int = 0) -> TrackedDfg: + in_tys = [tys.Qubit] * n_qb + [FLOAT_T] * float_in + return TrackedDfg(*in_tys, track_inputs=True) + + +def test_simple_circuit(): + circ = simple_circuit(2) + circ.append(H(0)) + [_h, cx_n] = circ.extend([H(0), CX(0, 1)]) + + circ.set_tracked_outputs() + + assert len(circ.hugr) == 6 + + # all nodes connected to output + out_ins = { + out.node + for _, outs in circ.hugr.incoming_links(circ.output_node) + for out in outs + } + assert out_ins == {cx_n} + validate(circ.hugr) + + +def test_complex_circuit(): + circ = simple_circuit(2) + fl = circ.load(FloatVal(0.5)) + + circ.extend([H(0), Rz(0, fl)]) + [_m0, m1] = circ.extend(Measure(i) for i in range(2)) + + m_idx = circ.track_wire(m1[1]) # track the bool out + assert m_idx == 2 + circ.append(Not(m_idx)) + + circ.set_tracked_outputs() + + assert len(circ.hugr) == 10 + + assert circ._output_op().types == [tys.Qubit, tys.Qubit, tys.Bool] + + validate(circ.hugr)