Skip to content

Commit

Permalink
rename ToPort to Wire
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed May 28, 2024
1 parent f87dcf7 commit 83a5b19
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
48 changes: 24 additions & 24 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass, field, replace

from collections.abc import Mapping
from enum import Enum
from typing import (
Expand Down Expand Up @@ -29,7 +29,7 @@ class Direction(Enum):

@dataclass(frozen=True, eq=True, order=True)
class _Port:
node: "Node"
node: Node
offset: int


Expand All @@ -38,20 +38,20 @@ class InPort(_Port):
direction: ClassVar[Direction] = Direction.INCOMING


class ToPort(Protocol):
def to_port(self) -> "OutPort": ...
class Wire(Protocol):
def out_port(self) -> OutPort: ...


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, ToPort):
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING

def to_port(self) -> "OutPort":
def out_port(self) -> OutPort:
return self


@dataclass(frozen=True, eq=True, order=True)
class Node(ToPort):
class Node(Wire):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

Expand Down Expand Up @@ -83,7 +83,7 @@ def __getitem__(
case tuple(xs):
return [self[i] for i in xs]

def to_port(self) -> "OutPort":
def out_port(self) -> "OutPort":
return OutPort(self, 0)

def inp(self, offset: int) -> InPort:
Expand All @@ -100,7 +100,7 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort:


class Op(Protocol):
def to_serial(self, node: Node, hugr: "Hugr") -> SerialOp: ...
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ...


T = TypeVar("T", bound=BaseOp)
Expand All @@ -110,13 +110,13 @@ def to_serial(self, node: Node, hugr: "Hugr") -> SerialOp: ...
class DummyOp(Op, Generic[T]):
_serial_op: T

def to_serial(self, node: Node, hugr: "Hugr") -> SerialOp:
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=self._serial_op) # type: ignore


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[ToPort]: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None

Expand All @@ -129,7 +129,7 @@ class NodeData:
_num_outs: int = 0
# TODO children field?

def to_serial(self, node: Node, hugr: "Hugr") -> SerialOp:
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, hugr)
o.root.parent = self.parent.idx if self.parent else node.idx

Expand Down Expand Up @@ -305,7 +305,7 @@ def num_outgoing(self, node: Node) -> int:

# TODO: num_links and _linked_ports

def insert_hugr(self, hugr: "Hugr", parent: Node | None = None) -> dict[Node, Node]:
def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node]:
mapping: dict[Node, Node] = {}

for idx, node_data in enumerate(hugr._nodes):
Expand Down Expand Up @@ -340,7 +340,7 @@ def to_serial(self) -> SerialHugr:
)

@classmethod
def from_serial(cls, serial: SerialHugr) -> "Hugr":
def from_serial(cls, serial: SerialHugr) -> Hugr:
raise NotImplementedError


Expand Down Expand Up @@ -371,7 +371,7 @@ def __init__(
)

@classmethod
def endo(cls, types: Sequence[Type]) -> "Dfg":
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> DummyOp[sops.Input]:
Expand All @@ -386,15 +386,15 @@ def inputs(self) -> list[OutPort]:
for i in range(len(self._input_op()._serial_op.types))
]

def add_op(self, op: Op, /, *args: ToPort, num_outs: int | None = None) -> Node:
def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out())

def insert_nested(self, dfg: "Dfg", *args: ToPort) -> Node:
def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]
Expand All @@ -403,8 +403,8 @@ def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
ports: Iterable[ToPort],
) -> "Dfg":
ports: Iterable[Wire],
) -> Dfg:
dfg = Dfg(input_types, output_types)
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], ports)
Expand All @@ -415,21 +415,21 @@ def add_nested(

return dfg

def set_outputs(self, *args: ToPort) -> None:
def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)

def make_tuple(self, tys: Sequence[Type], *args: ToPort) -> Node:
def make_tuple(self, tys: Sequence[Type], *args: Wire) -> Node:
ports = list(args)
assert len(tys) == len(ports), "Number of types must match number of ports"
return self.add_op(DummyOp(sops.MakeTuple(parent=0, tys=list(tys))), *args)

def split_tuple(self, tys: Sequence[Type], port: ToPort) -> list[OutPort]:
def split_tuple(self, tys: Sequence[Type], port: Wire) -> list[OutPort]:
tys = list(tys)
n = self.add_op(DummyOp(sops.UnpackTuple(parent=0, tys=tys)), port)

return [n.out(i) for i in range(len(tys))]

def _wire_up(self, node: Node, ports: Iterable[ToPort]):
def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.to_port()
src = p.out_port()
self.hugr.add_link(src, node.inp(i))
12 changes: 6 additions & 6 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import subprocess
import os
import pathlib
from hugr.hugr import Dfg, Hugr, DummyOp, Node, Command, ToPort, Op
from hugr.hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op
import hugr.serialization.tys as stys
import hugr.serialization.ops as sops
import pytest
Expand Down Expand Up @@ -32,9 +32,9 @@

@dataclass
class Not(Command):
a: ToPort
a: Wire

def incoming(self) -> list[ToPort]:
def incoming(self) -> list[Wire]:
return [self.a]

def num_out(self) -> int | None:
Expand All @@ -46,10 +46,10 @@ def op(self) -> Op:

@dataclass
class DivMod(Command):
a: ToPort
b: ToPort
a: Wire
b: Wire

def incoming(self) -> list[ToPort]:
def incoming(self) -> list[Wire]:
return [self.a, self.b]

def num_out(self) -> int | None:
Expand Down

0 comments on commit 83a5b19

Please sign in to comment.