Skip to content

Commit

Permalink
feat: put hugr.py behind an underscore
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed May 30, 2024
1 parent db0f4fc commit ae162d2
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
File renamed without changes.
135 changes: 135 additions & 0 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generic, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp, OpType as SerialOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node


class Op(ABC):
@abstractmethod
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: ...

@classmethod
def from_serial(cls, serial: SerialOp) -> Op:
match serial.root:
case sops.Input(types=types):
return Input(types=types)
case sops.Output(types=types):
return Output(types=types)
case sops.CustomOp(
extension=extension,
op_name=op_name,
signature=signature,
description=description,
args=args,
):
return Custom(
extension=extension,
op_name=op_name,
signature=signature,
description=description,
args=args,
)
case sops.MakeTuple(tys=types):
return MakeTuple(types=types)
case sops.UnpackTuple(tys=types):
return UnpackTuple(types=types)
case sops.DFG(signature=signature):
return DFG(signature=signature)
return SerWrap(serial.root)


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
root = self._serial_op.model_copy()
root.parent = parent.idx
return SerialOp(root=root) # type: ignore


@dataclass()
class Input(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=sops.Input(parent=parent.idx, types=self.types))


@dataclass()
class Output(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=sops.Output(parent=parent.idx, types=self.types))


@dataclass()
class Custom(Op):
extension: tys.ExtensionId
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
args: list[tys.TypeArg] = field(default_factory=list)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
return SerialOp(
root=sops.CustomOp(
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
description=self.description,
args=self.args,
)
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
return SerialOp(
root=sops.MakeTuple(
parent=parent.idx,
tys=self.types,
)
)


@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
return SerialOp(
root=sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
)
)


@dataclass()
class DFG(Op):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp:
return SerialOp(
root=sops.DFG(
parent=parent.idx,
signature=self.signature,
)
)
2 changes: 1 addition & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import subprocess
import os
import pathlib
from hugr.hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op
from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op
from hugr.serialization import SerialHugr
import hugr.serialization.tys as stys
import hugr.serialization.ops as sops
Expand Down

0 comments on commit ae162d2

Please sign in to comment.