Skip to content

Commit

Permalink
refactor: commands -> call on ops
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 4, 2024
1 parent e2af3df commit c2c8528
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 73 deletions.
11 changes: 2 additions & 9 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG
from hugr._ops import Op, Input, Output, DFG, Command
from hugr.utils import BiMap


Expand All @@ -44,13 +44,6 @@ class Wire(Protocol):
def out_port(self) -> OutPort: ...


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING
Expand Down Expand Up @@ -399,7 +392,7 @@ def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out())
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
Expand Down
31 changes: 29 additions & 2 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,25 @@
import hugr.serialization.tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node
from hugr._hugr import Hugr, Node, Wire

Check warning on line 10 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L10

Added line #L10 was not covered by tests


class Op(Protocol):
@property
def num_out(self) -> int | None:
return None

Check warning on line 16 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L16

Added line #L16 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...

def __call__(self, *args) -> Command:
return Command(self, list(args))


@dataclass(frozen=True)
class Command:
op: Op
incoming: list[Wire]


T = TypeVar("T", bound=BaseOp)

Expand All @@ -35,6 +48,9 @@ class Input(Op):
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)

def __call__(self) -> Command:
return super().__call__()

Check warning on line 52 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L52

Added line #L52 was not covered by tests


@dataclass()
class Output(Op):
Expand All @@ -46,10 +62,10 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:

@dataclass()
class Custom(Op):
extension: tys.ExtensionId
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
extension: tys.ExtensionId = ""
args: list[tys.TypeArg] = field(default_factory=list)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
Expand All @@ -66,24 +82,35 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, *elements: Wire) -> Command:
return super().__call__(*elements)


@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, tuple_: Wire) -> Command:
return super().__call__(tuple_)


@dataclass()
class DFG(Op):
Expand Down
93 changes: 31 additions & 62 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
import subprocess
import os
import pathlib
from hugr._hugr import Dfg, Hugr, Node, Command, Wire
from hugr._ops import Op, Custom
from hugr._hugr import Dfg, Hugr, Node, Wire
from hugr._ops import Custom, Command
import hugr._ops as ops
from hugr.serialization import SerialHugr
import hugr.serialization.tys as stys
Expand All @@ -23,76 +23,45 @@
)
)

# TODO get from YAML
NOT_OP = Custom(
extension="logic",
op_name="Not",
signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]),
)


@dataclass
class Not(Command):
a: Wire

def incoming(self) -> list[Wire]:
return [self.a]

def num_out(self) -> int | None:
return 1

def op(self) -> Op:
return NOT_OP
class LogicOps(Custom):
extension: stys.ExtensionId = "logic"


# TODO get from YAML
@dataclass
class DivMod(Command):
a: Wire
b: Wire
class NotDef(LogicOps):
num_out: int | None = 1
op_name: str = "Not"
signature: stys.FunctionType = field(
default_factory=lambda: stys.FunctionType(input=[BOOL_T], output=[BOOL_T])
)

def incoming(self) -> list[Wire]:
return [self.a, self.b]
def __call__(self, a: Wire) -> Command:
return super().__call__(a)

def num_out(self) -> int | None:
return 2

def op(self) -> Op:
return Custom(
extension="arithmetic.int",
op_name="idivmod_u",
signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2),
args=[ARG_5, ARG_5],
)
Not = NotDef()


@dataclass
class MakeTuple(Command):
types: list[stys.Type]
wires: list[Wire]

def incoming(self) -> list[Wire]:
return self.wires

def num_out(self) -> int | None:
return 1

def op(self) -> Op:
return ops.MakeTuple(self.types)
class IntOps(Custom):
extension: stys.ExtensionId = "arithmetic.int"


@dataclass
class UnpackTuple(Command):
types: list[stys.Type]
wire: Wire

def incoming(self) -> list[Wire]:
return [self.wire]
class DivModDef(IntOps):
num_out: int | None = 2
extension: stys.ExtensionId = "arithmetic.int"
op_name: str = "idivmod_u"
signature: stys.FunctionType = field(
default_factory=lambda: stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2)
)
args: list[stys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5])

def num_out(self) -> int | None:
return len(self.types)

def op(self) -> Op:
return ops.UnpackTuple(self.types)
DivMod = DivModDef()


def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True):
Expand All @@ -114,7 +83,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True):
def test_stable_indices():
h = Hugr(ops.DFG())

nodes = [h.add_node(NOT_OP) for _ in range(3)]
nodes = [h.add_node(Not) for _ in range(3)]
assert len(h) == 4

h.add_link(nodes[0].out(0), nodes[1].inp(0))
Expand All @@ -137,7 +106,7 @@ def test_stable_indices():
with pytest.raises(KeyError):
_ = h[Node(46)]

new_n = h.add_node(NOT_OP)
new_n = h.add_node(Not)
assert new_n == nodes[1]

assert len(h) == 4
Expand Down Expand Up @@ -178,7 +147,7 @@ def test_multiport():
def test_add_op():
h = Dfg.endo([BOOL_T])
(a,) = h.inputs()
nt = h.add_op(NOT_OP, a)
nt = h.add_op(Not, a)
h.set_outputs(nt)

_validate(h.hugr)
Expand All @@ -188,8 +157,8 @@ def test_tuple():
row = [BOOL_T, QB_T]
h = Dfg.endo(row)
a, b = h.inputs()
t = h.add(MakeTuple(row, [a, b]))
a, b = h.add(UnpackTuple(row, t))
t = h.add(ops.MakeTuple(row)(a, b))
a, b = h.add(ops.UnpackTuple(row)(t))
h.set_outputs(a, b)

_validate(h.hugr)
Expand Down

0 comments on commit c2c8528

Please sign in to comment.