Skip to content

Commit

Permalink
feat(hugr-py): add CallIndirect, LoadFunction, Lift, Alias (#1218)
Browse files Browse the repository at this point in the history
all ops covered!
commits are independent 
Closes #1213
  • Loading branch information
ss2165 authored Jun 24, 2024
1 parent af062ea commit db09193
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 45 deletions.
50 changes: 39 additions & 11 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind
from hugr._tys import (
Type,
TypeRow,
get_first_sum,
FunctionType,
TypeArg,
FunctionKind,
PolyFuncType,
ExtensionSet,
)

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, ParentBuilder
Expand Down Expand Up @@ -170,15 +179,9 @@ def call(
func: ToNode,
*args: Wire,
instantiation: FunctionType | None = None,
type_args: list[TypeArg] | None = None,
type_args: Sequence[TypeArg] | None = None,
) -> Node:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError("Expected 'func' to be a function")
signature = self._fn_sig(func)
call_op = ops.Call(signature, instantiation, type_args)
call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out)
self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset()))
Expand All @@ -187,6 +190,29 @@ def call(

return call_n

def load_function(
self,
func: ToNode,
instantiation: FunctionType | None = None,
type_args: Sequence[TypeArg] | None = None,
) -> Node:
signature = self._fn_sig(func)
load_op = ops.LoadFunc(signature, instantiation, type_args)
load_n = self.hugr.add_node(load_op, self.parent_node)
self.hugr.add_link(func.out(0), load_n.inp(0))

return load_n

def _fn_sig(self, func: ToNode) -> PolyFuncType:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError("Expected 'func' to be a function")
return signature

def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand All @@ -212,8 +238,10 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:


class Dfg(_DfBase[ops.DFG]):
def __init__(self, *input_types: Type) -> None:
parent_op = ops.DFG(list(input_types))
def __init__(
self, *input_types: Type, extension_delta: ExtensionSet | None = None
) -> None:
parent_op = ops.DFG(list(input_types), None, extension_delta or [])
super().__init__(parent_op)


Expand Down
8 changes: 7 additions & 1 deletion hugr-py/src/hugr/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._dfg import _DfBase
from hugr._node_port import Node
from ._hugr import Hugr
from ._tys import TypeRow, TypeParam, PolyFuncType
from ._tys import TypeRow, TypeParam, PolyFuncType, Type, TypeBound


@dataclass
Expand Down Expand Up @@ -47,3 +47,9 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node:

def add_const(self, value: val.Value) -> Node:
return self.hugr.add_node(ops.Const(value), self.hugr.root)

def add_alias_defn(self, name: str, ty: Type) -> Node:
return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root)

def add_alias_decl(self, name: str, bound: TypeBound) -> Node:
return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root)
202 changes: 185 additions & 17 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Protocol, TYPE_CHECKING, runtime_checkable, TypeVar
from typing import Protocol, TYPE_CHECKING, Sequence, runtime_checkable, TypeVar
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
from hugr.utils import ser_it
Expand Down Expand Up @@ -233,14 +233,15 @@ def _inputs(self) -> tys.TypeRow: ...
class DFG(DfParentOp, DataflowOp):
inputs: tys.TypeRow
_outputs: tys.TypeRow | None = None
extension_delta: tys.ExtensionSet = field(default_factory=list)

@property
def outputs(self) -> tys.TypeRow:
return _check_complete(self._outputs)

@property
def signature(self) -> tys.FunctionType:
return tys.FunctionType(self.inputs, self.outputs)
return tys.FunctionType(self.inputs, self.outputs, self.extension_delta)

@property
def num_out(self) -> int | None:
Expand Down Expand Up @@ -381,6 +382,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind:
@dataclass
class LoadConst(DataflowOp):
typ: tys.Type | None = None
num_out: int | None = 1

def type_(self) -> tys.Type:
return _check_complete(self.typ)
Expand Down Expand Up @@ -588,6 +590,25 @@ class NoConcreteFunc(Exception):
pass


def _fn_instantiation(
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: Sequence[tys.TypeArg] | None = None,
) -> tuple[tys.FunctionType, list[tys.TypeArg]]:
if len(signature.params) == 0:
return signature.body, []

else:
# TODO substitute type args into signature to get instantiation
if instantiation is None:
raise NoConcreteFunc("Missing instantiation for polymorphic function.")
type_args = type_args or []

if len(signature.params) != len(type_args):
raise NoConcreteFunc("Mismatched number of type arguments.")
return instantiation, list(type_args)


@dataclass
class Call(Op):
signature: tys.PolyFuncType
Expand All @@ -598,23 +619,12 @@ def __init__(
self,
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: list[tys.TypeArg] | None = None,
type_args: Sequence[tys.TypeArg] | None = None,
) -> None:
self.signature = signature
if len(signature.params) == 0:
self.instantiation = signature.body
self.type_args = []

else:
# TODO substitute type args into signature to get instantiation
if instantiation is None:
raise NoConcreteFunc("Missing instantiation for polymorphic function.")
type_args = type_args or []

if len(signature.params) != len(type_args):
raise NoConcreteFunc("Mismatched number of type arguments.")
self.instantiation = instantiation
self.type_args = type_args
self.instantiation, self.type_args = _fn_instantiation(
signature, instantiation, type_args
)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call:
return sops.Call(
Expand All @@ -637,3 +647,161 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.FunctionKind(self.signature)
case _:
return tys.ValueKind(_sig_port_type(self.instantiation, port))


@dataclass()
class CallIndirectDef(DataflowOp, PartialOp):
_signature: tys.FunctionType | None = None

@property
def num_out(self) -> int | None:
return len(self.signature.output)

@property
def signature(self) -> tys.FunctionType:
return _check_complete(self._signature)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CallIndirect:
return sops.CallIndirect(
parent=parent.idx,
signature=self.signature.to_serial(),
)

def __call__(self, function: Wire, *args: Wire) -> Command: # type: ignore[override]
return super().__call__(function, *args)

def outer_signature(self) -> tys.FunctionType:
sig = self.signature

return tys.FunctionType(input=[sig, *sig.input], output=sig.output)

def set_in_types(self, types: tys.TypeRow) -> None:
func_sig, *_ = types
assert isinstance(
func_sig, tys.FunctionType
), f"Expected function type, got {func_sig}"
self._signature = func_sig


# rename to eval?
CallIndirect = CallIndirectDef()


@dataclass
class LoadFunc(DataflowOp):
signature: tys.PolyFuncType
instantiation: tys.FunctionType
type_args: list[tys.TypeArg]
num_out: int | None = 1

def __init__(
self,
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: Sequence[tys.TypeArg] | None = None,
) -> None:
self.signature = signature
self.instantiation, self.type_args = _fn_instantiation(
signature, instantiation, type_args
)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadFunction:
return sops.LoadFunction(
parent=parent.idx,
func_sig=self.signature.to_serial(),
type_args=ser_it(self.type_args),
signature=self.outer_signature().to_serial(),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=[self.instantiation])

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
match port:
case InPort(_, 0):
return tys.FunctionKind(self.signature)
case OutPort(_, 0):
return tys.ValueKind(self.instantiation)
case _:
raise InvalidPort(port)


@dataclass
class NoopDef(DataflowOp, PartialOp):
_type: tys.Type | None = None
num_out: int | None = 1

@property
def type_(self) -> tys.Type:
return _check_complete(self._type)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Noop:
return sops.Noop(parent=parent.idx, ty=self.type_.to_serial_root())

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType.endo([self.type_])

def set_in_types(self, types: tys.TypeRow) -> None:
(t,) = types
self._type = t


Noop = NoopDef()


@dataclass
class Lift(DataflowOp, PartialOp):
new_extension: tys.ExtensionId
_type_row: tys.TypeRow | None = None
num_out: int | None = 1

@property
def type_row(self) -> tys.TypeRow:
return _check_complete(self._type_row)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Lift:
return sops.Lift(
parent=parent.idx,
new_extension=self.new_extension,
type_row=ser_it(self.type_row),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType.endo(self.type_row)

def set_in_types(self, types: tys.TypeRow) -> None:
self._type_row = types


@dataclass
class AliasDecl(Op):
name: str
bound: tys.TypeBound
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDecl:
return sops.AliasDecl(
parent=parent.idx,
name=self.name,
bound=self.bound,
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise InvalidPort(port)


@dataclass
class AliasDefn(Op):
name: str
definition: tys.Type
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDefn:
return sops.AliasDefn(
parent=parent.idx,
name=self.name,
definition=self.definition.to_serial_root(),
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise InvalidPort(port)
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,11 @@ class FunctionType(Type):
extension_reqs: ExtensionSet = field(default_factory=ExtensionSet)

def to_serial(self) -> stys.FunctionType:
return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output))
return stys.FunctionType(
input=ser_it(self.input),
output=ser_it(self.output),
extension_reqs=self.extension_reqs,
)

@classmethod
def empty(cls) -> FunctionType:
Expand Down
Loading

0 comments on commit db09193

Please sign in to comment.