diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 2f3a10eec..8fd42f846 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -12,7 +12,16 @@ import hugr._ops as ops import hugr._val as val -from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind +from hugr._tys import ( + Type, + TypeRow, + get_first_sum, + FunctionType, + TypeArg, + FunctionKind, + PolyFuncType, + ExtensionSet, +) from ._exceptions import NoSiblingAncestor from ._hugr import Hugr, ParentBuilder @@ -170,15 +179,9 @@ def call( func: ToNode, *args: Wire, instantiation: FunctionType | None = None, - type_args: list[TypeArg] | None = None, + type_args: Sequence[TypeArg] | None = None, ) -> Node: - f_op = self.hugr[func] - f_kind = f_op.op.port_kind(func.out(0)) - match f_kind: - case FunctionKind(sig): - signature = sig - case _: - raise ValueError("Expected 'func' to be a function") + signature = self._fn_sig(func) call_op = ops.Call(signature, instantiation, type_args) call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out) self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset())) @@ -187,6 +190,29 @@ def call( return call_n + def load_function( + self, + func: ToNode, + instantiation: FunctionType | None = None, + type_args: Sequence[TypeArg] | None = None, + ) -> Node: + signature = self._fn_sig(func) + load_op = ops.LoadFunc(signature, instantiation, type_args) + load_n = self.hugr.add_node(load_op, self.parent_node) + self.hugr.add_link(func.out(0), load_n.inp(0)) + + return load_n + + def _fn_sig(self, func: ToNode) -> PolyFuncType: + f_op = self.hugr[func] + f_kind = f_op.op.port_kind(func.out(0)) + match f_kind: + case FunctionKind(sig): + signature = sig + case _: + raise ValueError("Expected 'func' to be a function") + return signature + def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): @@ -212,8 +238,10 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: class Dfg(_DfBase[ops.DFG]): - def __init__(self, *input_types: Type) -> None: - parent_op = ops.DFG(list(input_types)) + def __init__( + self, *input_types: Type, extension_delta: ExtensionSet | None = None + ) -> None: + parent_op = ops.DFG(list(input_types), None, extension_delta or []) super().__init__(parent_op) diff --git a/hugr-py/src/hugr/_function.py b/hugr-py/src/hugr/_function.py index d8ffa578b..c1c54984c 100644 --- a/hugr-py/src/hugr/_function.py +++ b/hugr-py/src/hugr/_function.py @@ -8,7 +8,7 @@ from ._dfg import _DfBase from hugr._node_port import Node from ._hugr import Hugr -from ._tys import TypeRow, TypeParam, PolyFuncType +from ._tys import TypeRow, TypeParam, PolyFuncType, Type, TypeBound @dataclass @@ -47,3 +47,9 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: def add_const(self, value: val.Value) -> Node: return self.hugr.add_node(ops.Const(value), self.hugr.root) + + def add_alias_defn(self, name: str, ty: Type) -> Node: + return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root) + + def add_alias_decl(self, name: str, bound: TypeBound) -> Node: + return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 5bc63c79a..a2b68153c 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Protocol, TYPE_CHECKING, runtime_checkable, TypeVar +from typing import Protocol, TYPE_CHECKING, Sequence, runtime_checkable, TypeVar from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops from hugr.utils import ser_it @@ -233,6 +233,7 @@ def _inputs(self) -> tys.TypeRow: ... class DFG(DfParentOp, DataflowOp): inputs: tys.TypeRow _outputs: tys.TypeRow | None = None + extension_delta: tys.ExtensionSet = field(default_factory=list) @property def outputs(self) -> tys.TypeRow: @@ -240,7 +241,7 @@ def outputs(self) -> tys.TypeRow: @property def signature(self) -> tys.FunctionType: - return tys.FunctionType(self.inputs, self.outputs) + return tys.FunctionType(self.inputs, self.outputs, self.extension_delta) @property def num_out(self) -> int | None: @@ -381,6 +382,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: @dataclass class LoadConst(DataflowOp): typ: tys.Type | None = None + num_out: int | None = 1 def type_(self) -> tys.Type: return _check_complete(self.typ) @@ -588,6 +590,25 @@ class NoConcreteFunc(Exception): pass +def _fn_instantiation( + signature: tys.PolyFuncType, + instantiation: tys.FunctionType | None = None, + type_args: Sequence[tys.TypeArg] | None = None, +) -> tuple[tys.FunctionType, list[tys.TypeArg]]: + if len(signature.params) == 0: + return signature.body, [] + + else: + # TODO substitute type args into signature to get instantiation + if instantiation is None: + raise NoConcreteFunc("Missing instantiation for polymorphic function.") + type_args = type_args or [] + + if len(signature.params) != len(type_args): + raise NoConcreteFunc("Mismatched number of type arguments.") + return instantiation, list(type_args) + + @dataclass class Call(Op): signature: tys.PolyFuncType @@ -598,23 +619,12 @@ def __init__( self, signature: tys.PolyFuncType, instantiation: tys.FunctionType | None = None, - type_args: list[tys.TypeArg] | None = None, + type_args: Sequence[tys.TypeArg] | None = None, ) -> None: self.signature = signature - if len(signature.params) == 0: - self.instantiation = signature.body - self.type_args = [] - - else: - # TODO substitute type args into signature to get instantiation - if instantiation is None: - raise NoConcreteFunc("Missing instantiation for polymorphic function.") - type_args = type_args or [] - - if len(signature.params) != len(type_args): - raise NoConcreteFunc("Mismatched number of type arguments.") - self.instantiation = instantiation - self.type_args = type_args + self.instantiation, self.type_args = _fn_instantiation( + signature, instantiation, type_args + ) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call: return sops.Call( @@ -637,3 +647,161 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.FunctionKind(self.signature) case _: return tys.ValueKind(_sig_port_type(self.instantiation, port)) + + +@dataclass() +class CallIndirectDef(DataflowOp, PartialOp): + _signature: tys.FunctionType | None = None + + @property + def num_out(self) -> int | None: + return len(self.signature.output) + + @property + def signature(self) -> tys.FunctionType: + return _check_complete(self._signature) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CallIndirect: + return sops.CallIndirect( + parent=parent.idx, + signature=self.signature.to_serial(), + ) + + def __call__(self, function: Wire, *args: Wire) -> Command: # type: ignore[override] + return super().__call__(function, *args) + + def outer_signature(self) -> tys.FunctionType: + sig = self.signature + + return tys.FunctionType(input=[sig, *sig.input], output=sig.output) + + def set_in_types(self, types: tys.TypeRow) -> None: + func_sig, *_ = types + assert isinstance( + func_sig, tys.FunctionType + ), f"Expected function type, got {func_sig}" + self._signature = func_sig + + +# rename to eval? +CallIndirect = CallIndirectDef() + + +@dataclass +class LoadFunc(DataflowOp): + signature: tys.PolyFuncType + instantiation: tys.FunctionType + type_args: list[tys.TypeArg] + num_out: int | None = 1 + + def __init__( + self, + signature: tys.PolyFuncType, + instantiation: tys.FunctionType | None = None, + type_args: Sequence[tys.TypeArg] | None = None, + ) -> None: + self.signature = signature + self.instantiation, self.type_args = _fn_instantiation( + signature, instantiation, type_args + ) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadFunction: + return sops.LoadFunction( + parent=parent.idx, + func_sig=self.signature.to_serial(), + type_args=ser_it(self.type_args), + signature=self.outer_signature().to_serial(), + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=[], output=[self.instantiation]) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + match port: + case InPort(_, 0): + return tys.FunctionKind(self.signature) + case OutPort(_, 0): + return tys.ValueKind(self.instantiation) + case _: + raise InvalidPort(port) + + +@dataclass +class NoopDef(DataflowOp, PartialOp): + _type: tys.Type | None = None + num_out: int | None = 1 + + @property + def type_(self) -> tys.Type: + return _check_complete(self._type) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Noop: + return sops.Noop(parent=parent.idx, ty=self.type_.to_serial_root()) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType.endo([self.type_]) + + def set_in_types(self, types: tys.TypeRow) -> None: + (t,) = types + self._type = t + + +Noop = NoopDef() + + +@dataclass +class Lift(DataflowOp, PartialOp): + new_extension: tys.ExtensionId + _type_row: tys.TypeRow | None = None + num_out: int | None = 1 + + @property + def type_row(self) -> tys.TypeRow: + return _check_complete(self._type_row) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Lift: + return sops.Lift( + parent=parent.idx, + new_extension=self.new_extension, + type_row=ser_it(self.type_row), + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType.endo(self.type_row) + + def set_in_types(self, types: tys.TypeRow) -> None: + self._type_row = types + + +@dataclass +class AliasDecl(Op): + name: str + bound: tys.TypeBound + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDecl: + return sops.AliasDecl( + parent=parent.idx, + name=self.name, + bound=self.bound, + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise InvalidPort(port) + + +@dataclass +class AliasDefn(Op): + name: str + definition: tys.Type + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDefn: + return sops.AliasDefn( + parent=parent.idx, + name=self.name, + definition=self.definition.to_serial_root(), + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise InvalidPort(port) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 6e2e3584e..cfbb7c294 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -234,7 +234,11 @@ class FunctionType(Type): extension_reqs: ExtensionSet = field(default_factory=ExtensionSet) def to_serial(self) -> stys.FunctionType: - return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output)) + return stys.FunctionType( + input=ser_it(self.input), + output=ser_it(self.output), + extension_reqs=self.extension_reqs, + ) @classmethod def empty(cls) -> FunctionType: diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index a81ca95ba..b0bcfb1ab 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -43,9 +43,9 @@ def display_name(self) -> str: """Name of the op for visualisation""" return self.__class__.__name__ + @abstractmethod def deserialize(self) -> _ops.Op: """Deserializes the model into the corresponding Op.""" - raise NotImplementedError # ---------------------------------------------------------- @@ -334,6 +334,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(fun_ty.output) == len(out_types) self.signature = fun_ty + def deserialize(self) -> _ops.CallIndirectDef: + return _ops.CallIndirectDef(self.signature.deserialize()) + class LoadConstant(DataflowOp): """An operation that loads a static constant in to the local dataflow graph.""" @@ -353,6 +356,19 @@ class LoadFunction(DataflowOp): type_args: list[tys.TypeArg] signature: FunctionType + def deserialize(self) -> _ops.LoadFunc: + signature = self.signature.deserialize() + assert len(signature.input) == 0 + (f_ty,) = signature.output + assert isinstance( + f_ty, _tys.FunctionType + ), "Expected single funciton type output" + return _ops.LoadFunc( + self.func_sig.deserialize(), + f_ty, + deser_it(self.type_args), + ) + class DFG(DataflowOp): """A simply nested dataflow graph.""" @@ -367,7 +383,7 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: def deserialize(self) -> _ops.DFG: sig = self.signature.deserialize() - return _ops.DFG(sig.input, sig.output) + return _ops.DFG(sig.input, sig.output, sig.extension_reqs) # ------------------------------------------------ @@ -520,6 +536,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert in_types[0] == out_types[0] self.ty = in_types[0] + def deserialize(self) -> _ops.NoopDef: + return _ops.NoopDef(self.ty.deserialize()) + class MakeTuple(DataflowOp): """An operation that packs all its inputs into a tuple.""" @@ -571,18 +590,30 @@ class Lift(DataflowOp): type_row: TypeRow new_extension: ExtensionId + def deserialize(self) -> _ops.Lift: + return _ops.Lift( + _type_row=deser_it(self.type_row), + new_extension=self.new_extension, + ) + class AliasDecl(BaseOp): op: Literal["AliasDecl"] = "AliasDecl" name: str bound: TypeBound + def deserialize(self) -> _ops.AliasDecl: + return _ops.AliasDecl(self.name, self.bound) + class AliasDefn(BaseOp): op: Literal["AliasDefn"] = "AliasDefn" name: str definition: Type + def deserialize(self) -> _ops.AliasDefn: + return _ops.AliasDefn(self.name, self.definition.deserialize()) + class OpType(RootModel): """A constant operation.""" diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index cad65436d..5d54cb177 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -318,7 +318,8 @@ def test_vals(val: val.Value): _validate(d.hugr) -def test_poly_function() -> None: +@pytest.mark.parametrize("direct_call", [True, False]) +def test_poly_function(direct_call: bool) -> None: mod = Module() f_id = mod.declare_function( "id", @@ -330,21 +331,28 @@ def test_poly_function() -> None: f_main = mod.define_main([tys.Qubit]) q = f_main.input_node[0] - with pytest.raises(NoConcreteFunc, match="Missing instantiation"): - f_main.call(f_id, q) - call = f_main.call( - f_id, - q, - # for now concrete instantiations have to be provided. - instantiation=tys.FunctionType.endo([tys.Qubit]), - type_args=[tys.Qubit.type_arg()], - ) + # for now concrete instantiations have to be provided. + instantiation = tys.FunctionType.endo([tys.Qubit]) + type_args = [tys.Qubit.type_arg()] + if direct_call: + with pytest.raises(NoConcreteFunc, match="Missing instantiation"): + f_main.call(f_id, q) + call = f_main.call(f_id, q, instantiation=instantiation, type_args=type_args) + else: + with pytest.raises(NoConcreteFunc, match="Missing instantiation"): + f_main.load_function(f_id) + load = f_main.load_function( + f_id, instantiation=instantiation, type_args=type_args + ) + call = f_main.add(ops.CallIndirect(load, q)) + f_main.set_outputs(call) _validate(mod.hugr, True) -def test_mono_function() -> None: +@pytest.mark.parametrize("direct_call", [True, False]) +def test_mono_function(direct_call: bool) -> None: mod = Module() f_id = mod.define_function("id", [tys.Qubit]) f_id.set_outputs(f_id.input_node[0]) @@ -352,7 +360,40 @@ def test_mono_function() -> None: f_main = mod.define_main([tys.Qubit]) q = f_main.input_node[0] # monomorphic functions don't need instantiation specified - call = f_main.call(f_id, q) + if direct_call: + call = f_main.call(f_id, q) + else: + load = f_main.load_function(f_id) + call = f_main.add(ops.CallIndirect(load, q)) f_main.set_outputs(call) - _validate(mod.hugr, True) + _validate(mod.hugr) + + +def test_higher_order() -> None: + noop_fn = Dfg(tys.Qubit) + noop_fn.set_outputs(noop_fn.add(ops.Noop(noop_fn.input_node[0]))) + + d = Dfg(tys.Qubit) + (q,) = d.inputs() + f_val = d.load(val.Function(noop_fn.hugr)) + call = d.add(ops.CallIndirect(f_val, q))[0] + d.set_outputs(call) + + _validate(d.hugr) + + +def test_lift() -> None: + d = Dfg(tys.Qubit, extension_delta=["X"]) + (q,) = d.inputs() + lift = d.add(ops.Lift("X")(q)) + d.set_outputs(lift) + _validate(d.hugr) + + +def test_alias() -> None: + mod = Module() + _dfn = mod.add_alias_defn("my_int", INT_T) + _dcl = mod.add_alias_decl("my_bool", tys.TypeBound.Eq) + + _validate(mod.hugr)