diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/_hugr.py similarity index 100% rename from hugr-py/src/hugr/hugr.py rename to hugr-py/src/hugr/_hugr.py diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py new file mode 100644 index 000000000..c645d0664 --- /dev/null +++ b/hugr-py/src/hugr/ops.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Generic, TypeVar, TYPE_CHECKING +from hugr.serialization.ops import BaseOp, OpType as SerialOp +import hugr.serialization.ops as sops +import hugr.serialization.tys as tys + +if TYPE_CHECKING: + from hugr._hugr import Hugr, Node + + +class Op(ABC): + @abstractmethod + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: ... + + @classmethod + def from_serial(cls, serial: SerialOp) -> Op: + match serial.root: + case sops.Input(types=types): + return Input(types=types) + case sops.Output(types=types): + return Output(types=types) + case sops.CustomOp( + extension=extension, + op_name=op_name, + signature=signature, + description=description, + args=args, + ): + return Custom( + extension=extension, + op_name=op_name, + signature=signature, + description=description, + args=args, + ) + case sops.MakeTuple(tys=types): + return MakeTuple(types=types) + case sops.UnpackTuple(tys=types): + return UnpackTuple(types=types) + case sops.DFG(signature=signature): + return DFG(signature=signature) + return SerWrap(serial.root) + + +T = TypeVar("T", bound=BaseOp) + + +@dataclass() +class SerWrap(Op, Generic[T]): + # catch all for serial ops that don't have a corresponding Op class + _serial_op: T + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + root = self._serial_op.model_copy() + root.parent = parent.idx + return SerialOp(root=root) # type: ignore + + +@dataclass() +class Input(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + return SerialOp(root=sops.Input(parent=parent.idx, types=self.types)) + + +@dataclass() +class Output(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + return SerialOp(root=sops.Output(parent=parent.idx, types=self.types)) + + +@dataclass() +class Custom(Op): + extension: tys.ExtensionId + op_name: str + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + description: str = "" + args: list[tys.TypeArg] = field(default_factory=list) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + return SerialOp( + root=sops.CustomOp( + parent=parent.idx, + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + description=self.description, + args=self.args, + ) + ) + + +@dataclass() +class MakeTuple(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + return SerialOp( + root=sops.MakeTuple( + parent=parent.idx, + tys=self.types, + ) + ) + + +@dataclass() +class UnpackTuple(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + return SerialOp( + root=sops.UnpackTuple( + parent=parent.idx, + tys=self.types, + ) + ) + + +@dataclass() +class DFG(Op): + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: + return SerialOp( + root=sops.DFG( + parent=parent.idx, + signature=self.signature, + ) + ) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 830e5734e..522da06f6 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,7 +3,7 @@ import subprocess import os import pathlib -from hugr.hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op +from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op from hugr.serialization import SerialHugr import hugr.serialization.tys as stys import hugr.serialization.ops as sops