diff --git a/guppylang/checker/stmt_checker.py b/guppylang/checker/stmt_checker.py index 13ca25e6..b4f8db48 100644 --- a/guppylang/checker/stmt_checker.py +++ b/guppylang/checker/stmt_checker.py @@ -11,7 +11,7 @@ import ast from collections.abc import Sequence -from guppylang.ast_util import AstVisitor, with_loc +from guppylang.ast_util import AstVisitor, with_loc, with_type from guppylang.cfg.bb import BB, BBStatement from guppylang.checker.core import Context, FieldAccess, Variable from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer @@ -53,7 +53,7 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr: case ast.Name(id=x): var = Variable(x, ty, lhs) self.ctx.locals[x] = var - return with_loc(lhs, PlaceNode(place=var)) + return with_loc(lhs, with_type(ty, PlaceNode(place=var))) # The LHS could also be a field `expr.field` case ast.Attribute(value=value, attr=attr): @@ -93,7 +93,7 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr: "Mutation of classical fields is not supported yet", lhs ) place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs) - return with_loc(lhs, PlaceNode(place=place)) + return with_loc(lhs, with_type(ty, PlaceNode(place=place))) # The only other thing we support right now are tuples case ast.Tuple(elts=elts) as lhs: @@ -109,7 +109,7 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr: self._check_assign(pat, el_ty, node) for pat, el_ty in zip(elts, tys, strict=True) ] - return lhs + return with_type(ty, lhs) # TODO: Python also supports assignments like `[a, b] = [1, 2]` or # `a, *b = ...`. The former would require some runtime checks but diff --git a/guppylang/compiler/cfg_compiler.py b/guppylang/compiler/cfg_compiler.py index 8ae096de..fc28daee 100644 --- a/guppylang/compiler/cfg_compiler.py +++ b/guppylang/compiler/cfg_compiler.py @@ -1,5 +1,11 @@ import functools -from typing import TYPE_CHECKING +from collections.abc import Sequence + +from hugr import Wire, ops +from hugr import cfg as hc +from hugr import tys as ht +from hugr.dfg import DP, _DfBase +from hugr.node_port import ToNode from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Row, Signature from guppylang.checker.core import Place, Variable @@ -11,62 +17,68 @@ ) from guppylang.compiler.expr_compiler import ExprCompiler from guppylang.compiler.stmt_compiler import StmtCompiler -from guppylang.hugr_builder.hugr import CFNode, Hugr, Node, OutPortV -from guppylang.tys.builtin import is_bool_type from guppylang.tys.ty import SumType, row_to_type, type_to_row -if TYPE_CHECKING: - from collections.abc import Sequence - def compile_cfg( - cfg: CheckedCFG[Place], graph: Hugr, parent: Node, globals: CompiledGlobals -) -> None: + cfg: CheckedCFG[Place], + container: _DfBase[DP], + inputs: Sequence[Wire], + globals: CompiledGlobals, +) -> hc.Cfg: """Compiles a CFG to Hugr.""" insert_return_vars(cfg) - blocks: dict[CheckedBB[Place], CFNode] = {} + builder = container.add_cfg(*inputs) + + blocks: dict[CheckedBB[Place], ToNode] = {} for bb in cfg.bbs: - blocks[bb] = compile_bb(bb, graph, parent, bb == cfg.entry_bb, globals) + blocks[bb] = compile_bb(bb, builder, bb == cfg.entry_bb, globals) for bb in cfg.bbs: - for succ in bb.successors: - graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) + for i, succ in enumerate(bb.successors): + builder.branch(blocks[bb][i], blocks[succ]) + + return builder def compile_bb( bb: CheckedBB[Place], - graph: Hugr, - parent: Node, + builder: hc.Cfg, is_entry: bool, globals: CompiledGlobals, -) -> CFNode: - """Compiles a single basic block to Hugr.""" - inputs = bb.sig.input_row if is_entry else sort_vars(bb.sig.input_row) +) -> ToNode: + """Compiles a single basic block to Hugr, and returns the resulting block. + If the basic block is the output block, returns `None`. + """ # The exit BB is completely empty if len(bb.successors) == 0: assert len(bb.statements) == 0 - return graph.add_exit([v.ty for v in inputs], parent) + return builder.exit # Otherwise, we use a regular `Block` node - block = graph.add_block(parent) + block: hc.Block + inputs: Sequence[Place] + if is_entry: + inputs = bb.sig.input_row + block = builder.add_entry() + else: + inputs = sort_vars(bb.sig.input_row) + block = builder.add_block(*(v.ty.to_hugr() for v in inputs)) # Add input node and compile the statements - inp = graph.add_input(output_tys=[v.ty for v in inputs], parent=block) - dfg = DFContainer(graph, block) - for i, v in enumerate(inputs): - dfg[v] = inp.out_port(i) - dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, dfg) + dfg = DFContainer(block) + for v, wire in zip(inputs, block.input_node, strict=True): + dfg[v] = wire + dfg = StmtCompiler(globals).compile_stmts(bb.statements, dfg) # If we branch, we also have to compile the branch predicate if len(bb.successors) > 1: assert bb.branch_pred is not None - branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg) + branch_port = ExprCompiler(globals).compile(bb.branch_pred, dfg) else: # Even if we don't branch, we still have to add a `Sum(())` predicates - branch_port = graph.add_tag( - variants=[[]], tag=0, inputs=[], parent=block - ).out_port(0) + branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1))) # Finally, we have to add the block output. outputs: Sequence[Place] @@ -87,7 +99,6 @@ def compile_bb( # ordering on variables which puts linear variables at the end. The only # exception are return vars which must be outputted in order. branch_port = choose_vars_for_tuple_sum( - graph=graph, unit_sum=branch_port, output_vars=[ [ @@ -101,9 +112,7 @@ def compile_bb( ) outputs = [v for v in first if v.ty.linear and not is_return_var(str(v))] - graph.add_output( - inputs=[branch_port] + [dfg[v] for v in sort_vars(outputs)], parent=block - ) + block.set_block_outputs(branch_port, *(dfg[v] for v in sort_vars(outputs))) return block @@ -130,27 +139,23 @@ def insert_return_vars(cfg: CheckedCFG[Place]) -> None: def choose_vars_for_tuple_sum( - graph: Hugr, unit_sum: OutPortV, output_vars: list[Row[Place]], dfg: DFContainer -) -> OutPortV: + unit_sum: Wire, output_vars: list[Row[Place]], dfg: DFContainer +) -> Wire: """Selects an output based on a TupleSum. Given `unit_sum: Sum(*(), *(), ...)` and output variable rows `#s1, #s2, ...`, constructs a TupleSum value of type `Sum(#s1, #s2, ...)`. """ - assert isinstance(unit_sum.ty, SumType) or is_bool_type(unit_sum.ty) - assert len(output_vars) == ( - len(unit_sum.ty.element_types) if isinstance(unit_sum.ty, SumType) else 2 - ) assert all(not v.ty.linear for var_row in output_vars for v in var_row) - conditional = graph.add_conditional(cond_input=unit_sum, inputs=[], parent=dfg.node) tys = [[v.ty for v in var_row] for var_row in output_vars] - for i, var_row in enumerate(output_vars): - case = graph.add_case(conditional) - graph.add_input(output_tys=[], parent=case) - inputs = [dfg[v] for v in var_row] - tag = graph.add_tag(variants=tys, tag=i, inputs=inputs, parent=case).out_port(0) - graph.add_output(inputs=[tag], parent=case) - return conditional.add_out_port(SumType([row_to_type(row) for row in tys])) + sum_type = SumType([row_to_type(row) for row in tys]).to_hugr() + + with dfg.builder.add_conditional(unit_sum) as conditional: + for i, var_row in enumerate(output_vars): + with conditional.add_case(i) as case: + tag = case.add_op(ops.Tag(i, sum_type), *(dfg[v] for v in var_row)) + case.set_outputs(tag) + return conditional def compare_var(p1: Place, p2: Place) -> int: diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index 5110b65b..0abc642a 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -1,32 +1,43 @@ from abc import ABC from dataclasses import dataclass, field +from typing import cast + +from hugr import Wire, ops +from hugr.dfg import DP, _DfBase from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable from guppylang.definition.common import CompiledDef, DefId from guppylang.error import InternalGuppyError -from guppylang.hugr_builder.hugr import DFContainingNode, Hugr, OutPortV from guppylang.tys.ty import StructType CompiledGlobals = dict[DefId, CompiledDef] -CompiledLocals = dict[PlaceId, OutPortV] +CompiledLocals = dict[PlaceId, Wire] @dataclass class DFContainer: """A dataflow graph under construction. - This class is passed through the entire compilation pipeline and stores the node - whose dataflow child-graph is currently being constructed as well as all live local + This class is passed through the entire compilation pipeline and stores a builder + for the dataflow child-graph currently being constructed as well as all live local variables. Note that the variable map is mutated in-place and always reflects the current compilation state. """ - graph: Hugr - node: DFContainingNode + builder: _DfBase[ops.DfParentOp] locals: CompiledLocals = field(default_factory=dict) - def __getitem__(self, place: Place) -> OutPortV: - """Constructs a port for a local place in this DFG. + def __init__( + self, builder: _DfBase[DP], locals: CompiledLocals | None = None + ) -> None: + generic_builder = cast(_DfBase[ops.DfParentOp], builder) + if locals is None: + locals = {} + self.builder = generic_builder + self.locals = locals + + def __getitem__(self, place: Place) -> Wire: + """Constructs a wire for a local place in this DFG. Note that this mutates the Hugr since we might need to pack or unpack some tuples to obtain a port for places that involve struct fields. @@ -39,23 +50,24 @@ def __getitem__(self, place: Place) -> OutPortV: if not isinstance(place.ty, StructType): raise InternalGuppyError(f"Couldn't obtain a port for `{place}`") children = [FieldAccess(place, field, None) for field in place.ty.fields] - child_ports = [self[child] for child in children] - port = self.graph.add_make_tuple(child_ports, self.node).out_port(0) + child_types = [child.ty.to_hugr() for child in children] + child_wires = [self[child] for child in children] + wire = self.builder.add_op(ops.MakeTuple(child_types), *child_wires)[0] for child in children: if child.ty.linear: self.locals.pop(child.id) - self.locals[place.id] = port - return port + self.locals[place.id] = wire + return wire - def __setitem__(self, place: Place, port: OutPortV) -> None: + def __setitem__(self, place: Place, port: Wire) -> None: # When assigning a struct value, we immediately unpack it recursively and only # store the leaf wires. is_return = isinstance(place, Variable) and is_return_var(place.name) if isinstance(place.ty, StructType) and not is_return: - unpack = self.graph.add_unpack_tuple(port, self.node) - for field, field_port in zip( - place.ty.fields, unpack.out_ports, strict=True - ): + unpack = self.builder.add_op( + ops.UnpackTuple([t.ty.to_hugr() for t in place.ty.fields]), port + ) + for field, field_port in zip(place.ty.fields, unpack, strict=True): self[FieldAccess(place, field, None)] = field_port # If we had a previous wire assigned to this place, we need forget about it. # Otherwise, we might use this old value when looking up the place later @@ -66,17 +78,15 @@ def __setitem__(self, place: Place, port: OutPortV) -> None: def __copy__(self) -> "DFContainer": # Make a copy of the var map so that mutating the copy doesn't # mutate our variable mapping - return DFContainer(self.graph, self.node, self.locals.copy()) + return DFContainer(self.builder, self.locals.copy()) class CompilerBase(ABC): """Base class for the Guppy compiler.""" - graph: Hugr globals: CompiledGlobals - def __init__(self, graph: Hugr, globals: CompiledGlobals) -> None: - self.graph = graph + def __init__(self, globals: CompiledGlobals) -> None: self.globals = globals diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index f5935e21..52c1dc9c 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -1,10 +1,18 @@ import ast import json -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from contextlib import contextmanager from typing import Any, TypeGuard, TypeVar -from hugr.serialization import ops, tys +import hugr +import hugr.std.float +import hugr.std.int +import hugr.std.logic +from hugr import Wire, ops +from hugr import tys as ht +from hugr import val as hv +from hugr.cond_loop import Conditional +from hugr.dfg import DP, _DfBase from typing_extensions import assert_never from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type @@ -13,13 +21,6 @@ from guppylang.compiler.core import CompilerBase, DFContainer from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.error import GuppyError, InternalGuppyError -from guppylang.hugr_builder.hugr import ( - UNDEFINED, - DFContainingNode, - DummyOp, - OutPortV, - VNode, -) from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, @@ -33,7 +34,6 @@ TypeApply, ) from guppylang.tys.builtin import ( - bool_type, get_element_type, is_bool_type, is_list_type, @@ -51,20 +51,18 @@ ) -class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): +class ExprCompiler(CompilerBase, AstVisitor[Wire]): """A compiler from guppylang expressions to Hugr.""" dfg: DFContainer - def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: - """Compiles an expression and returns a single port holding the output value.""" + def compile(self, expr: ast.expr, dfg: DFContainer) -> Wire: + """Compiles an expression and returns a single wire holding the output value.""" self.dfg = dfg - with self.graph.parent(dfg.node): - res = self.visit(expr) - return res + return self.visit(expr) - def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: - """Compiles a row expression and returns a list of ports, one for each value in + def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[Wire]: + """Compiles a row expression and returns a list of wires, one for each value in the row. On Python-level, we treat tuples like rows on top-level. However, nested tuples @@ -72,25 +70,31 @@ def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: """ return [self.compile(e, dfg) for e in expr_to_row(expr)] + @property + def builder(self) -> _DfBase[ops.DfParentOp]: + """The current Hugr dataflow graph builder.""" + return self.dfg.builder + @contextmanager def _new_dfcontainer( - self, inputs: list[PlaceNode], node: DFContainingNode + self, inputs: list[PlaceNode], builder: _DfBase[DP] ) -> Iterator[None]: """Context manager to build a graph inside a new `DFContainer`. Automatically updates `self.dfg` and makes the inputs available. """ old = self.dfg - inp = self.graph.add_input(parent=node) # Check that the input names are unique assert len({inp.place.id for inp in inputs}) == len( inputs ), "Inputs are not unique" - self.dfg = DFContainer(self.graph, node, self.dfg.locals.copy()) - for input_node in inputs: - self.dfg[input_node.place] = inp.add_out_port(input_node.place.ty) - with self.graph.parent(node): - yield + self.dfg = DFContainer(builder, self.dfg.locals.copy()) + hugr_input = builder.input_node + for input_node, wire in zip(inputs, hugr_input, strict=True): + self.dfg[input_node.place] = wire + + yield + self.dfg = old @contextmanager @@ -98,39 +102,44 @@ def _new_loop( self, loop_vars: list[PlaceNode], branch: PlaceNode, - parent: DFContainingNode | None = None, ) -> Iterator[None]: """Context manager to build a graph inside a new `TailLoop` node. Automatically adds the `Output` node to the loop body once the context manager exits. """ - loop = self.graph.add_tail_loop( - [self.visit(name) for name in loop_vars], parent - ) + loop_inputs = [self.visit(name) for name in loop_vars] + loop = self.builder.add_tail_loop([], loop_inputs) with self._new_dfcontainer(loop_vars, loop): yield # Output the branch predicate and the inputs for the next iteration - self.graph.add_output( + loop.set_loop_outputs( # Note that we have to do fresh calls to `self.visit` here since we're # in a new context - [self.visit(branch), *(self.visit(name) for name in loop_vars)] + self.visit(branch), + *(self.visit(name) for name in loop_vars), ) # Update the DFG with the outputs from the loop - for node in loop_vars: - self.dfg[node.place] = loop.add_out_port(node.place.ty) + for node, wire in zip(loop_vars, loop, strict=True): + self.dfg[node.place] = wire @contextmanager def _new_case( - self, inputs: list[PlaceNode], outputs: list[PlaceNode], cond_node: VNode + self, + inputs: list[PlaceNode], + outputs: list[PlaceNode], + conditional: Conditional, + case_id: int, ) -> Iterator[None]: """Context manager to build a graph inside a new `Case` node. Automatically adds the `Output` node once the context manager exits. """ - with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): + # TODO: `Case` is `_DfgBase`, but not `Dfg`? + case = conditional.add_case(case_id) + with self._new_dfcontainer(inputs, case): yield - self.graph.add_output([self.visit(name) for name in outputs]) + case.set_outputs(*(self.visit(name) for name in outputs)) @contextmanager def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]: @@ -138,29 +147,28 @@ def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]: In the `false` case, the inputs are outputted as is. """ - cond_node = self.graph.add_conditional( - self.visit(cond), [self.visit(inp) for inp in inputs] + conditional = self.builder.add_conditional( + self.visit(cond), *(self.visit(inp) for inp in inputs) ) - # If the condition is false, output the inputs as is - with self._new_case(inputs, inputs, cond_node): - pass # If the condition is true, we enter the `with` block - with self._new_case(inputs, inputs, cond_node): + with self._new_case(inputs, inputs, conditional, 0): yield + # If the condition is false, output the inputs as is + with self._new_case(inputs, inputs, conditional, 1): + pass # Update the DFG with the outputs from the Conditional node - for node in inputs: - self.dfg[node.place] = cond_node.add_out_port(node.place.ty) + for node, wire in zip(inputs, conditional, strict=True): + self.dfg[node.place] = wire - def visit_Constant(self, node: ast.Constant) -> OutPortV: + def visit_Constant(self, node: ast.Constant) -> Wire: if value := python_value_to_hugr(node.value, get_type(node)): - const = self.graph.add_constant(value, get_type(node)).out_port(0) - return self.graph.add_load_constant(const).out_port(0) + return self.builder.load(value) raise InternalGuppyError("Unsupported constant expression in compiler") - def visit_PlaceNode(self, node: PlaceNode) -> OutPortV: + def visit_PlaceNode(self, node: PlaceNode) -> Wire: return self.dfg[node.place] - def visit_GlobalName(self, node: GlobalName) -> OutPortV: + def visit_GlobalName(self, node: GlobalName) -> Wire: defn = self.globals[node.def_id] assert isinstance(defn, CompiledValueDef) if isinstance(defn, CompiledCallableDef) and defn.ty.parametrized: @@ -169,100 +177,130 @@ def visit_GlobalName(self, node: GlobalName) -> OutPortV: "supported yet", node, ) - return defn.load(self.dfg, self.graph, self.globals, node) + return defn.load(self.dfg, self.globals, node) - def visit_Name(self, node: ast.Name) -> OutPortV: + def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") - def visit_Tuple(self, node: ast.Tuple) -> OutPortV: - return self.graph.add_make_tuple( - inputs=[self.visit(e) for e in node.elts] - ).out_port(0) + def visit_Tuple(self, node: ast.Tuple) -> Wire: + elems = [self.visit(e) for e in node.elts] + types = [get_type(e) for e in node.elts] + return self._pack_tuple(elems, types) - def visit_List(self, node: ast.List) -> OutPortV: + def visit_List(self, node: ast.List) -> Wire: # Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension - return self.graph.add_node( - DummyOp("MakeList"), inputs=[self.visit(e) for e in node.elts] - ).add_out_port(get_type(node)) + inputs = [self.visit(e) for e in node.elts] + in_types = [get_type(e) for e in node.elts] + out_type = get_type(node) + return self.builder.add_op( + make_list_op(in_types, out_type), + *inputs, + ) - def _unpack_tuple(self, wire: OutPortV) -> list[OutPortV]: - unpack_node = self.graph.add_unpack_tuple(wire, self.dfg.node) - return list(unpack_node.out_ports) + def _unpack_tuple(self, wire: Wire, types: Sequence[Type]) -> Sequence[Wire]: + """Add a tuple unpack operation to the graph""" + types = [t.to_hugr() for t in types] + return list(self.builder.add_op(ops.UnpackTuple(types), wire)) - def _pack_returns(self, returns: list[OutPortV], return_ty: Type) -> OutPortV: + def _pack_tuple(self, wires: Sequence[Wire], types: Sequence[Type]) -> Wire: + """Add a tuple pack operation to the graph""" + types = [t.to_hugr() for t in types] + return self.builder.add_op(ops.MakeTuple(types), *wires) + + def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire: """Groups function return values into a tuple""" if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve: - assert len(returns) == ( - len(return_ty.element_types) if isinstance(return_ty, TupleType) else 0 - ) - return self.graph.add_make_tuple(inputs=returns).out_port(0) - assert len(returns) == 1 + types = type_to_row(return_ty) + assert len(returns) == len(types) + return self._pack_tuple(returns, types) + assert len(returns) == 1, ( + f"Expected a single return value. Got {returns}. " + f"return type {return_ty}" + ) return returns[0] - def visit_LocalCall(self, node: LocalCall) -> OutPortV: + def visit_LocalCall(self, node: LocalCall) -> Wire: func = self.visit(node.func) - assert isinstance(func.ty, FunctionType) + func_ty = get_type(node.func) + assert isinstance(func_ty, FunctionType) args = [self.visit(arg) for arg in node.args] - call = self.graph.add_indirect_call(func, args) - rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.output)))] - return self._pack_returns(rets, func.ty.output) + call = self.builder.add_op(ops.CallIndirect(func_ty.to_hugr()), func, *args) + return self._pack_returns(list(call), func_ty.output) - def visit_TensorCall(self, node: TensorCall) -> OutPortV: - func = self.visit(node.func) - args = [self.visit(arg) for arg in node.args] + def visit_TensorCall(self, node: TensorCall) -> Wire: + functions: Wire = self.visit(node.func) + function_types = get_type(node.func) - assert isinstance(func.ty, TupleType) + args = [self.visit(arg) for arg in node.args] + assert isinstance(function_types, TupleType) - rets: list[OutPortV] = [] + rets: list[Wire] = [] remaining_args = args - for elem in self._unpack_tuple(func): + for func, func_ty in zip( + self._unpack_tuple(functions, function_types.element_types), + function_types.element_types, + strict=True, + ): outs, remaining_args = self._compile_tensor_with_leftovers( - elem, remaining_args + func, func_ty, remaining_args ) rets.extend(outs) - assert remaining_args == [] + assert ( + remaining_args == [] + ), "Not all function arguments were consumed after a tensor call" return self._pack_returns(rets, node.out_tys) def _compile_tensor_with_leftovers( - self, func: OutPortV, args: list[OutPortV] + self, func: Wire, func_ty: Type, args: list[Wire] ) -> tuple[ - list[OutPortV], # Compiled outputs - list[OutPortV], - ]: # Leftover args - if isinstance(func.ty, TupleType): + list[Wire], # Compiled outputs + list[Wire], # Leftover args + ]: + """Compiles a function call, consuming as many arguments as needed, and + returning the unused ones. + """ + if isinstance(func_ty, TupleType): remaining_args = args all_outs = [] - for elem in self._unpack_tuple(func): + for elem, ty in zip( + self._unpack_tuple(func, func_ty.element_types), + func_ty.element_types, + strict=True, + ): outs, remaining_args = self._compile_tensor_with_leftovers( - elem, remaining_args + elem, ty, remaining_args ) all_outs.extend(outs) return all_outs, remaining_args - elif isinstance(func.ty, FunctionType): - input_len = len(func.ty.inputs) - call = self.graph.add_indirect_call(func, args[0:input_len]) + elif isinstance(func_ty, FunctionType): + input_len = len(func_ty.inputs) + consumed_args, other_args = args[0:input_len], args[input_len:] - return list(call.out_ports), args[input_len:] + call = self.builder.add_op( + ops.CallIndirect(func_ty.to_hugr()), func, *consumed_args + ) + + return list(call), other_args else: raise InternalGuppyError("Tensor element wasn't function or tuple") - def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: + def visit_GlobalCall(self, node: GlobalCall) -> Wire: func = self.globals[node.def_id] assert isinstance(func, CompiledCallableDef) args = [self.visit(arg) for arg in node.args] rets = func.compile_call( - args, list(node.type_args), self.dfg, self.graph, self.globals, node + args, list(node.type_args), self.dfg, self.globals, node ) return self._pack_returns(rets, func.ty.output) - def visit_Call(self, node: ast.Call) -> OutPortV: + def visit_Call(self, node: ast.Call) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") - def visit_TypeApply(self, node: TypeApply) -> OutPortV: + def visit_TypeApply(self, node: TypeApply) -> Wire: # For now, we can only TypeApply global FunctionDefs/Decls. if not isinstance(node.value, GlobalName): raise InternalGuppyError("Dynamic TypeApply not supported yet!") @@ -282,39 +320,34 @@ def visit_TypeApply(self, node: TypeApply) -> OutPortV: node, ) - return defn.load_with_args(node.inst, self.dfg, self.graph, self.globals, node) + return defn.load_with_args(node.inst, self.dfg, self.globals, node) - def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: + def visit_UnaryOp(self, node: ast.UnaryOp) -> Wire: # The only case that is not desugared by the type checker is the `not` operation # since it is not implemented via a dunder method if isinstance(node.op, ast.Not): arg = self.visit(node.operand) - op = ops.CustomOp(extension="logic", name="Not", args=[], parent=UNDEFINED) - return self.graph.add_node(ops.OpType(op), inputs=[arg]).add_out_port( - bool_type() - ) + return self.builder.add_op(hugr.std.logic.Not, arg) raise InternalGuppyError("Node should have been removed during type checking.") - def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> OutPortV: + def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> Wire: struct_port = self.visit(node.value) - unpack = self.graph.add_unpack_tuple(struct_port) - return unpack.out_port(node.struct_ty.fields.index(node.field)) + field_idx = node.struct_ty.fields.index(node.field) + return self._unpack_tuple(struct_port, [f.ty for f in node.struct_ty.fields])[ + field_idx + ] - def visit_ResultExpr(self, node: ResultExpr) -> OutPortV: + def visit_ResultExpr(self, node: ResultExpr) -> Wire: extra_args = [] if isinstance(node.base_ty, NumericType): match node.base_ty.kind: case NumericType.Kind.Nat: base_name = "uint" - extra_args = [ - tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH)) - ] + extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)] case NumericType.Kind.Int: base_name = "int" - extra_args = [ - tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH)) - ] + extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)] case NumericType.Kind.Float: base_name = "f64" case kind: @@ -328,7 +361,7 @@ def visit_ResultExpr(self, node: ResultExpr) -> OutPortV: match node.array_len: case ConstValue(value=value): assert isinstance(value, int) - extra_args = [tys.TypeArg(tys.BoundedNatArg(n=value)), *extra_args] + extra_args = [ht.BoundedNatArg(n=value), *extra_args] case BoundConstVar(): # TODO: We need to handle this once we allow function definitions # that are generic over array lengths @@ -341,37 +374,44 @@ def visit_ResultExpr(self, node: ResultExpr) -> OutPortV: assert_never(c) else: op_name = f"result_{base_name}" - args = [tys.TypeArg(tys.StringArg(arg=node.tag)), *extra_args] - op = ops.CustomOp( + args = [ + ht.StringArg(node.tag), + *extra_args, + ] + sig = ht.FunctionType( + input=[get_type(node.value).to_hugr()], + output=[], + ) + op = ops.Custom( extension="tket2.result", name=op_name, args=args, - parent=UNDEFINED, + signature=sig, ) - self.graph.add_node(ops.OpType(op), inputs=[self.visit(node.value)]) + self.builder.add_op(op, self.visit(node.value)) return self._pack_returns([], NoneType()) - def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV: + def visit_DesugaredListComp(self, node: DesugaredListComp) -> Wire: from guppylang.compiler.stmt_compiler import StmtCompiler - compiler = StmtCompiler(self.graph, self.globals) + compiler = StmtCompiler(self.globals) # Make up a name for the list under construction and bind it to an empty list list_ty = get_type(node) list_place = Variable(next(tmp_vars), list_ty, node) list_name = with_type(list_ty, with_loc(node, PlaceNode(place=list_place))) - empty_list = self.graph.add_node(DummyOp("MakeList")) - self.dfg[list_place] = empty_list.add_out_port(list_ty) + self.dfg[list_place] = self.builder.add_op(make_list_op([], list_ty)) def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: """Helper function to generate nested TailLoop nodes for generators""" # If there are no more generators left, just append the element to the list if not gens: list_port, elt_port = self.visit(list_name), self.visit(elt) - push = self.graph.add_node( - DummyOp("Push"), inputs=[list_port, elt_port] + elt_ty = get_type(elt) + push = self.builder.add_op( + list_push_op(list_ty, elt_ty), list_port, elt_port ) - self.dfg[list_place] = push.add_out_port(list_port.ty) + self.dfg[list_place] = push return # Otherwise, compile the first iterator and construct a TailLoop @@ -407,10 +447,10 @@ def compile_ifs(ifs: list[ast.expr]) -> None: compile_generators(node.elt, node.generators) return self.visit(list_name) - def visit_BinOp(self, node: ast.BinOp) -> OutPortV: + def visit_BinOp(self, node: ast.BinOp) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") - def visit_Compare(self, node: ast.Compare) -> OutPortV: + def visit_Compare(self, node: ast.Compare) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") @@ -427,25 +467,20 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: return False -def python_value_to_hugr(v: Any, exp_ty: Type) -> ops.Value | None: +def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: """Turns a Python value into a Hugr value. Returns None if the Python value cannot be represented in Guppy. """ - from guppylang.prelude._internal import ( - bool_value, - float_value, - int_value, - list_value, - ) + from guppylang.prelude._internal.util import ListVal match v: case bool(): - return bool_value(v) + return hv.bool_value(v) case int(): - return int_value(v) + return hugr.std.int.IntVal(v, width=NumericType.INT_WIDTH) case float(): - return float_value(v) + return hugr.std.float.FloatVal(v) case tuple(elts): assert isinstance(exp_ty, TupleType) vs = [ @@ -453,12 +488,12 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> ops.Value | None: for elt, ty in zip(elts, exp_ty.element_types, strict=True) ] if doesnt_contain_none(vs): - return ops.Value(ops.TupleValue(vs=vs)) + return hv.Tuple(*vs) case list(elts): assert is_list_type(exp_ty) vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts] if doesnt_contain_none(vs): - return list_value(vs, get_element_type(exp_ty)) + return ListVal(vs, get_element_type(exp_ty)) case _: # Pytket conversion is an optional feature try: @@ -469,13 +504,34 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> ops.Value | None: Tk2Circuit, ) - hugr = json.loads(Tk2Circuit(v).to_hugr_json()) # type: ignore[attr-defined, unused-ignore] - return ops.Value(ops.FunctionValue(hugr=hugr)) + circ = json.loads(Tk2Circuit(v).to_hugr_json()) # type: ignore[attr-defined, unused-ignore] + return hv.Function(circ) except ImportError: pass return None +def make_dummy_op( + name: str, inp: Sequence[Type], out: Sequence[Type] +) -> ops.DataflowOp: + """Dummy operation.""" + input = [ty.to_hugr() for ty in inp] + output = [ty.to_hugr() for ty in out] + + sig = ht.FunctionType(input=input, output=output) + return ops.Custom(name=name, extension="dummy", signature=sig, args=[]) + + +def make_list_op(in_types: Sequence[Type], out_type: Type) -> ops.DataflowOp: + """Creates a dummy operation for constructing a list.""" + return make_dummy_op("MakeList", in_types, [out_type]) + + +def list_push_op(list_ty: Type, elem_ty: Type) -> ops.DataflowOp: + """Creates a dummy operation for constructing a list.""" + return make_dummy_op("Push", [list_ty, elem_ty], [list_ty]) + + T = TypeVar("T") diff --git a/guppylang/compiler/func_compiler.py b/guppylang/compiler/func_compiler.py index 417116c8..94b6ff7c 100644 --- a/guppylang/compiler/func_compiler.py +++ b/guppylang/compiler/func_compiler.py @@ -1,10 +1,14 @@ +from collections.abc import Sequence from typing import TYPE_CHECKING +from hugr import Wire, ops +from hugr import tys as ht +from hugr.function import Function + from guppylang.compiler.cfg_compiler import compile_cfg from guppylang.compiler.core import CompiledGlobals, DFContainer -from guppylang.hugr_builder.hugr import DFContainingVNode, Hugr, OutPortV from guppylang.nodes import CheckedNestedFunctionDef -from guppylang.tys.ty import FunctionType, type_to_row +from guppylang.tys.ty import FunctionType, Type if TYPE_CHECKING: from guppylang.definition.function import CheckedFunctionDef @@ -12,55 +16,58 @@ def compile_global_func_def( func: "CheckedFunctionDef", - def_node: DFContainingVNode, - graph: Hugr, + builder: Function, globals: CompiledGlobals, ) -> None: """Compiles a top-level function definition to Hugr.""" - _, ports = graph.add_input_with_ports(list(func.ty.inputs), def_node) - cfg_node = graph.add_cfg(def_node, ports) - compile_cfg(func.cfg, graph, cfg_node, globals) - - # Add output node for the cfg - graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], - parent=def_node, - ) + cfg = compile_cfg(func.cfg, builder, builder.inputs(), globals) + + builder.set_outputs(*cfg) def compile_local_func_def( func: CheckedNestedFunctionDef, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, -) -> OutPortV: - """Compiles a local (nested) function definition to Hugr.""" +) -> Wire: + """Compiles a local (nested) function definition to Hugr and loads it into a value. + + Returns the wire output of the `LoadFunc` operation. + """ assert func.ty.input_names is not None # Pick an order for the captured variables captured = list(func.captured.values()) + captured_types = [v.ty for v, _ in captured] + + # Whether the function calls itself recursively. + recursive = func.name in func.cfg.live_before[func.cfg.entry_bb] # Prepend captured variables to the function arguments closure_ty = FunctionType( - [v.ty for v, _ in captured] + list(func.ty.inputs), + captured_types + list(func.ty.inputs), func.ty.output, - [v.name for v, _ in captured] + list(func.ty.input_names), + input_names=[v.name for v, _ in captured] + list(func.ty.input_names), ) + hugr_closure_ty: ht.FunctionType = closure_ty.to_hugr() - def_node = graph.add_def(closure_ty, dfg.node, func.name) - def_input, input_ports = graph.add_input_with_ports( - list(closure_ty.inputs), def_node + func_builder = dfg.builder.define_function( + func.name, hugr_closure_ty.input, hugr_closure_ty.output ) # If we have captured variables and the body contains a recursive occurrence of # the function itself, then we provide the partially applied function as a local # variable - if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: - loaded = graph.add_load_function(def_node.out_port(0), [], def_node).out_port(0) - partial = graph.add_partial( - loaded, [def_input.out_port(i) for i in range(len(captured))], def_node + call_args: list[Wire] = list(func_builder.inputs()) + if len(captured) > 0 and recursive: + loaded = func_builder.load_function(func_builder, hugr_closure_ty) + partial = func_builder.add_op( + make_partial_op(closure_ty, captured_types), + loaded, + *func_builder.input_node[: len(captured)], ) - input_ports.append(partial.out_port(0)) + + call_args.append(partial) func.cfg.input_tys.append(func.ty) else: # Otherwise, we treat the function like a normal global variable @@ -75,25 +82,62 @@ def compile_local_func_def( {}, None, func.cfg, - def_node, + func_builder, ) } # Compile the CFG - cfg_node = graph.add_cfg(def_node, inputs=input_ports) - compile_cfg(func.cfg, graph, cfg_node, globals) - - # Add output node for the cfg - graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], - parent=def_node, - ) + cfg = compile_cfg(func.cfg, func_builder, call_args, globals) + func_builder.set_outputs(*cfg) # Finally, load the function into the local data-flow graph - loaded = graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0) + loaded = dfg.builder.load_function(func_builder, hugr_closure_ty) if len(captured) > 0: - loaded = graph.add_partial( - loaded, [dfg[v] for v, _ in captured], dfg.node - ).out_port(0) + loaded = dfg.builder.add_op( + make_partial_op(closure_ty, captured_types), + loaded, + *(dfg[v] for v, _ in captured), + ) return loaded + + +def make_dummy_op( + name: str, inp: Sequence[Type], out: Sequence[Type], extension: str = "dummy" +) -> ops.DataflowOp: + """Dummy operation.""" + input = [ty.to_hugr() for ty in inp] + output = [ty.to_hugr() for ty in out] + + sig = ht.FunctionType(input=input, output=output) + return ops.Custom(name=name, extension=extension, signature=sig, args=[]) + + +def make_partial_op( + closure_ty: FunctionType, captured_tys: Sequence[Type] +) -> ops.DataflowOp: + """Creates a dummy operation for partially evaluating a function. + + args: + closure_ty: A function type `(c_0, ..., c_k, a_0, ..., a_n) -> b_0, ..., b_m` + captured_tys: A list of types `c_0, ..., c_k` that are captured by the function + + returns: + An operation with type + ` (c_0, ..., c_k, a_0, ..., a_n -> b_0, ..., b_m ), c_0, ..., c_k` + `-> (a_0, ..., a_n -> b_0, ..., b_m)` + """ + assert len(closure_ty.inputs) >= len(captured_tys) + assert [p.to_hugr() for p in captured_tys] == [ + ty.to_hugr() for ty in closure_ty.inputs[: len(captured_tys)] + ] + + explicit_inputs = closure_ty.inputs[len(captured_tys) :] + partially_applied_func = FunctionType(explicit_inputs, closure_ty.output) + + return make_dummy_op( + "partial", + [closure_ty, *captured_tys], + [partially_applied_func], + extension="guppylang.unsupported", + ) diff --git a/guppylang/compiler/stmt_compiler.py b/guppylang/compiler/stmt_compiler.py index aa9e2689..cdd7c803 100644 --- a/guppylang/compiler/stmt_compiler.py +++ b/guppylang/compiler/stmt_compiler.py @@ -1,6 +1,9 @@ import ast from collections.abc import Sequence +from hugr import Wire, ops +from hugr.dfg import _DfBase + from guppylang.ast_util import AstVisitor, get_type from guppylang.checker.core import Variable from guppylang.compiler.core import ( @@ -11,9 +14,8 @@ ) from guppylang.compiler.expr_compiler import ExprCompiler from guppylang.error import InternalGuppyError -from guppylang.hugr_builder.hugr import Hugr, OutPortV from guppylang.nodes import CheckedNestedFunctionDef, PlaceNode -from guppylang.tys.ty import TupleType +from guppylang.tys.ty import TupleType, Type class StmtCompiler(CompilerBase, AstVisitor[None]): @@ -23,9 +25,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]): dfg: DFContainer - def __init__(self, graph: Hugr, globals: CompiledGlobals): - super().__init__(graph, globals) - self.expr_compiler = ExprCompiler(graph, globals) + def __init__(self, globals: CompiledGlobals): + super().__init__(globals) + self.expr_compiler = ExprCompiler(globals) def compile_stmts( self, @@ -42,14 +44,20 @@ def compile_stmts( self.visit(s) return self.dfg - def _unpack_assign(self, lhs: ast.expr, port: OutPortV, node: ast.stmt) -> None: + @property + def builder(self) -> _DfBase[ops.DfParentOp]: + """The Hugr dataflow graph builder.""" + return self.dfg.builder + + def _unpack_assign(self, lhs: ast.expr, port: Wire, node: ast.stmt) -> None: """Updates the local DFG with assignments.""" if isinstance(lhs, PlaceNode): self.dfg[lhs.place] = port elif isinstance(lhs, ast.Tuple): - unpack = self.graph.add_unpack_tuple(port, self.dfg.node) - for i, pat in enumerate(lhs.elts): - self._unpack_assign(pat, unpack.out_port(i), node) + types = [get_type(e).to_hugr() for e in lhs.elts] + unpack = self.builder.add_op(ops.UnpackTuple(types), port) + for pat, wire in zip(lhs.elts, unpack, strict=True): + self._unpack_assign(pat, wire, node) else: raise InternalGuppyError("Invalid assign pattern in compiler") @@ -75,17 +83,22 @@ def visit_Return(self, node: ast.Return) -> None: if node.value is not None: return_ty = get_type(node.value) port = self.expr_compiler.compile(node.value, self.dfg) + + row: list[tuple[Wire, Type]] if isinstance(return_ty, TupleType): - unpack = self.graph.add_unpack_tuple(port, self.dfg.node) - row = [unpack.out_port(i) for i in range(len(return_ty.element_types))] + types = [e.to_hugr() for e in return_ty.element_types] + unpack = self.builder.add_op(ops.UnpackTuple(types), port) + row = list(zip(unpack, return_ty.element_types, strict=True)) else: - row = [port] - for i, port in enumerate(row): - var = Variable(return_var(i), port.ty, node.value) - self.dfg[var] = port + row = [(port, return_ty)] + + for i, (wire, ty) in enumerate(row): + var = Variable(return_var(i), ty, node.value) + self.dfg[var] = wire def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: from guppylang.compiler.func_compiler import compile_local_func_def var = Variable(node.name, node.ty, node) - self.dfg[var] = compile_local_func_def(node, self.dfg, self.graph, self.globals) + loaded_func = compile_local_func_def(node, self.dfg, self.globals) + self.dfg[var] = loaded_func diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 442ba6fb..8216c803 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -6,7 +6,8 @@ from types import ModuleType from typing import Any, TypeVar -from hugr.serialization import ops, tys +from hugr import Hugr, ops +from hugr import tys as ht from guppylang.ast_util import annotate_location, has_empty_body from guppylang.definition.common import DefId @@ -25,8 +26,8 @@ from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError, MissingModuleError, pretty_errors -from guppylang.hugr_builder.hugr import Hugr from guppylang.module import GuppyModule, PyFunc +from guppylang.tys.subst import Inst from guppylang.tys.ty import NumericType FuncDefDecorator = Callable[[PyFunc], RawFunctionDef] @@ -122,10 +123,10 @@ def dec(c: type) -> type: def type( self, module: GuppyModule, - hugr_ty: tys.Type, + hugr_ty: ht.Type, name: str = "", linear: bool = False, - bound: tys.TypeBound | None = None, + bound: ht.TypeBound | None = None, ) -> ClassDecorator: """Decorator to annotate a class definitions as Guppy types. @@ -222,12 +223,22 @@ def dec(f: PyFunc) -> RawCustomFunctionDef: def hugr_op( self, module: GuppyModule, - op: ops.OpType, + op: Callable[[ht.FunctionType, Inst], ops.DataflowOp], checker: CustomCallChecker | None = None, higher_order_value: bool = True, name: str = "", ) -> CustomFuncDecorator: - """Decorator to annotate function declarations as HUGR ops.""" + """Decorator to annotate function declarations as HUGR ops. + + Args: + module: The module in which the function should be defined. + op: A function that takes an instantiation of the type arguments as well as + the inferred input and output types and returns a concrete HUGR op. + checker: The custom call checker. + higher_order_value: Whether the function may be used as a higher-order + value. + name: The name of the function. + """ return self.custom(module, OpCompiler(op), checker, higher_order_value, name) def declare(self, module: GuppyModule) -> FuncDeclDecorator: @@ -296,7 +307,7 @@ def take_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: raise MissingModuleError(err) return self._modules.pop(id) - def compile_module(self, id: ModuleIdentifier | None = None) -> Hugr | None: + def compile_module(self, id: ModuleIdentifier | None = None) -> Hugr[ops.Module]: """Compiles the local module into a Hugr.""" module = self.take_module(id) if not module: diff --git a/guppylang/definition/common.py b/guppylang/definition/common.py index 9a5cf561..61bd073e 100644 --- a/guppylang/definition/common.py +++ b/guppylang/definition/common.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, ClassVar, TypeAlias -from guppylang.hugr_builder.hugr import Hugr, Node +from hugr.dfg import OpVar, _DefinitionBuilder if TYPE_CHECKING: from guppylang.checker.core import Globals @@ -24,6 +24,10 @@ class DefId: This id is persistent across all compilation stages. It can be used to identify a definition at any step in the compilation pipeline. + + Args: + id: An integer uniquely identifying the definition. + module: The module where the definition was defined. """ id: int @@ -43,6 +47,11 @@ class Definition(ABC): Each definition is identified by a globally unique id. Furthermore, we store the user-picked name for the defined object and an optional AST node for the definition location. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. """ id: DefId @@ -65,6 +74,11 @@ class ParsableDef(Definition): For example, raw function definitions first need to parse their signature and check that all types are valid. The result of parsing should be a definition that is ready to be checked. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. """ @abstractmethod @@ -79,6 +93,11 @@ class CheckableDef(Definition): """Abstract base class for definitions that still need to be checked. The result of checking should be a definition that is ready to be compiled to Hugr. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. """ @abstractmethod @@ -97,11 +116,16 @@ class CompilableDef(Definition): The result of compilation should be a `CompiledDef` with a pointer to the Hugr node that was created for this definition. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. """ @abstractmethod - def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledDef": - """Adds a Hugr node for the definition to the provided graph. + def compile_outer(self, module: _DefinitionBuilder[OpVar]) -> "CompiledDef": + """Adds a Hugr node for the definition to the provided Hugr module. Note that is not required to fill in the contents of the node. At this point, we don't have access to the globals since they have not all been compiled yet. @@ -112,9 +136,15 @@ def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledDef": class CompiledDef(Definition): - """Abstract base class for definitions that have been added to a Hugr.""" + """Abstract base class for definitions that have been added to a Hugr. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. + """ - def compile_inner(self, graph: Hugr, globals: "CompiledGlobals") -> None: + def compile_inner(self, globals: "CompiledGlobals") -> None: """Optional hook that is called to fill in the content of the Hugr node. Opposed to `CompilableDef.compile()`, we have access to all other compiled diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index 83a452d6..57918cad 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -1,10 +1,13 @@ import ast from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from hugr.serialization import ops +from hugr import Wire, ops +from hugr import tys as ht +from hugr.dfg import _DfBase -from guppylang.ast_util import AstNode, get_type, with_loc, with_type +from guppylang.ast_util import AstNode, with_loc, with_type from guppylang.checker.core import Context, Globals from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import check_signature @@ -12,10 +15,9 @@ from guppylang.definition.common import ParsableDef from guppylang.definition.value import CompiledCallableDef from guppylang.error import GuppyError, InternalGuppyError -from guppylang.hugr_builder.hugr import Hugr, OutPortV from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst -from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row +from guppylang.tys.ty import FunctionType, NoneType, Type @dataclass(frozen=True) @@ -27,6 +29,14 @@ class RawCustomFunctionDef(ParsableDef): The raw definition stores exactly what the user has written (i.e. the AST together with the provided checker and compiler), without inspecting the signature. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. + call_checker: The custom call checker. + call_compiler: The custom call compiler. + higher_order_value: Whether the function may be used as a higher-order value. """ defined_at: ast.FunctionDef @@ -51,27 +61,7 @@ def parse(self, globals: "Globals") -> "CustomFunctionDef": code. The only information we need to access is that it's a function type and that there are no unsolved existential vars. """ - # Type annotations are needed if we rely on the default call checker or want - # to allow the usage of the function as a higher-order value - requires_type_annotation = ( - isinstance(self.call_checker, DefaultCallChecker) or self.higher_order_value - ) - has_type_annotation = self.defined_at.returns or any( - arg.annotation for arg in self.defined_at.args.args - ) - - if requires_type_annotation and not has_type_annotation: - raise GuppyError( - f"Type signature for function `{self.name}` is required. " - "Alternatively, try passing `higher_order_value=False` on definition.", - self.defined_at, - ) - - ty = ( - check_signature(self.defined_at, globals) - if requires_type_annotation - else FunctionType([], NoneType()) - ) + ty = self._get_signature(globals) or FunctionType([], NoneType()) return CustomFunctionDef( self.id, self.name, @@ -84,26 +74,73 @@ def parse(self, globals: "Globals") -> "CustomFunctionDef": def compile_call( self, - args: list[OutPortV], + args: list[Wire], type_args: Inst, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, node: AstNode, - ) -> list[OutPortV]: + function_ty: ht.FunctionType, + ) -> Sequence[Wire]: """Compiles a call to the function.""" - self.call_compiler._setup(type_args, dfg, graph, globals, node) + # Note: We have _compiled_ globals rather than `Globals` here, + # so we cannot use `self._get_signature()`. + self.call_compiler._setup( + type_args, + dfg, + globals, + node, + function_ty, + ) return self.call_compiler.compile(args) + def _get_signature(self, globals: Globals) -> FunctionType | None: + """Returns the type of the function, if known. + + Type annotations are needed if we rely on the default call checker or + want to allow the usage of the function as a higher-order value. + + Some function types like python's `int()` cannot be expressed in the Guppy + type system, so we return `None` here and rely on the specialized compiler + to handle the call. + """ + requires_type_annotation = ( + isinstance(self.call_checker, DefaultCallChecker) or self.higher_order_value + ) + has_type_annotation = self.defined_at.returns or any( + arg.annotation for arg in self.defined_at.args.args + ) + + if requires_type_annotation and not has_type_annotation: + raise GuppyError( + f"Type signature for function `{self.name}` is required. " + "Alternatively, try passing `higher_order_value=False` on definition.", + self.defined_at, + ) + + if requires_type_annotation: + return check_signature(self.defined_at, globals) + else: + return None + @dataclass(frozen=True) class CustomFunctionDef(CompiledCallableDef): - """A custom function with parsed and checked signature.""" + """A custom function with parsed and checked signature. + + Args: + id: The unique definition identifier. + name: The name of the definition. + defined_at: The AST node where the definition was defined. + ty: The type of the function. + call_checker: The custom call checker. + call_compiler: The custom call compiler. + higher_order_value: Whether the function may be used as a higher-order value. + """ defined_at: AstNode + ty: FunctionType call_checker: "CustomCallChecker" call_compiler: "CustomCallCompiler" - ty: FunctionType higher_order_value: bool description: str = field(default="function", init=False) @@ -134,15 +171,14 @@ def load_with_args( self, type_args: Inst, dfg: "DFContainer", - graph: Hugr, globals: CompiledGlobals, node: AstNode, - ) -> OutPortV: + ) -> Wire: """Loads the custom function as a value into a local dataflow graph. - This will place a `FunctionDef` node into the Hugr module and loads it into the - DFG. This operation will fail the function is not allowed to be used as a - higher-order value. + This will place a `FunctionDef` node in the local DFG, and load with a + `LoadFunc` node. This operation will fail if the function is not allowed + to be used as a higher-order value. """ # TODO: This should be raised during checking, not compilation! if not self.higher_order_value: @@ -156,35 +192,40 @@ def load_with_args( # function, and returns the results. If the function signature is polymorphic, # we explicitly monomorphise here and invoke the call compiler with the # inferred type args. + # + # TODO: Reuse compiled instances with the same type args? + # TODO: Why do we need to monomorphise here? Why not wait for `load_function`? + # See https://github.com/CQCL/guppylang/issues/393 for both issues. fun_ty = self.ty.instantiate(type_args) - def_node = graph.add_def(fun_ty, dfg.node, self.name) - with graph.parent(def_node): - _, inp_ports = graph.add_input_with_ports(list(fun_ty.inputs)) - returns = self.compile_call( - inp_ports, - type_args, - DFContainer(graph, def_node), - graph, - globals, - node, - ) - graph.add_output(returns) + input_types = [ty.to_hugr() for ty in fun_ty.inputs] + output_types = [fun_ty.output.to_hugr()] + func = dfg.builder.define_function( + self.name, input_types, output_types, type_params=[] + ) + + func_dfg = DFContainer(func, dfg.locals.copy()) + args: list[Wire] = list(func.inputs()) + outputs = self.compile_call(args, type_args, func_dfg, globals, node) + + func.set_outputs(*outputs) # Finally, load the function into the local DFG. We already monomorphised, so we - # can load with empty type args - return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0) + # don't need to give it type arguments. + return dfg.builder.load_function(func) def compile_call( self, - args: list[OutPortV], + args: list[Wire], type_args: Inst, - dfg: DFContainer, - graph: Hugr, + dfg: "DFContainer", globals: CompiledGlobals, node: AstNode, - ) -> list[OutPortV]: + ) -> list[Wire]: """Compiles a call to the function.""" - self.call_compiler._setup(type_args, dfg, graph, globals, node) + concrete_ty = self.ty.instantiate(type_args) + hugr_ty = concrete_ty.to_hugr() + + self.call_compiler._setup(type_args, dfg, globals, node, hugr_ty) return self.call_compiler.compile(args) @@ -216,31 +257,47 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: class CustomCallCompiler(ABC): - """Abstract base class for custom function call compilers.""" + """Abstract base class for custom function call compilers. + + Args: + builder: The function builder where the function should be defined. + type_args: The type arguments for the function. + globals: The compiled globals. + node: The AST node where the function is defined. + ty: The type of the function, if known. + """ - type_args: Inst dfg: DFContainer - graph: Hugr + type_args: Inst globals: CompiledGlobals node: AstNode + ty: ht.FunctionType def _setup( self, type_args: Inst, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, node: AstNode, + hugr_ty: ht.FunctionType, ) -> None: self.type_args = type_args self.dfg = dfg - self.graph = graph self.globals = globals self.node = node + self.ty = hugr_ty @abstractmethod - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - """Compiles a custom function call and returns the resulting ports.""" + def compile(self, args: list[Wire]) -> list[Wire]: + """Compiles a custom function call and returns the resulting ports. + + Use the provided `self.builder` to add nodes to the Hugr graph. + """ + + @property + def builder(self) -> _DfBase[ops.DfParentOp]: + """The hugr dataflow builder.""" + return self.dfg.builder class DefaultCallChecker(CustomCallChecker): @@ -265,28 +322,31 @@ class NotImplementedCallCompiler(CustomCallCompiler): thus doesn't need to be compiled. """ - def compile(self, args: list[OutPortV]) -> list[OutPortV]: + def compile(self, args: list[Wire]) -> list[Wire]: raise InternalGuppyError("Function should have been removed during checking") class OpCompiler(CustomCallCompiler): - """Call compiler for functions that are directly implemented via Hugr ops.""" + """Call compiler for functions that are directly implemented via Hugr ops. - op: ops.OpType + args: + op: A function that takes an instantiation of the type arguments as well as + the monomorphic function type, and returns a concrete HUGR op. + """ + + op: Callable[[ht.FunctionType, Inst], ops.DataflowOp] - def __init__(self, op: ops.OpType) -> None: + def __init__(self, op: Callable[[ht.FunctionType, Inst], ops.DataflowOp]) -> None: self.op = op - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - node = self.graph.add_node( - self.op.model_copy(deep=True), inputs=args, parent=self.dfg.node - ) - return_ty = get_type(self.node) - return [node.add_out_port(ty) for ty in type_to_row(return_ty)] + def compile(self, args: list[Wire]) -> list[Wire]: + op = self.op(self.ty, self.type_args) + node = self.builder.add_op(op, *args) + return list(node) class NoopCompiler(CustomCallCompiler): """Call compiler for functions that are noops.""" - def compile(self, args: list[OutPortV]) -> list[OutPortV]: + def compile(self, args: list[Wire]) -> list[Wire]: return args diff --git a/guppylang/definition/declaration.py b/guppylang/definition/declaration.py index 7789318d..aff0743b 100644 --- a/guppylang/definition/declaration.py +++ b/guppylang/definition/declaration.py @@ -1,6 +1,11 @@ import ast from dataclasses import dataclass, field +from hugr import Node, Wire +from hugr import function as hf +from hugr import tys as ht +from hugr.dfg import OpVar, _DefinitionBuilder + from guppylang.ast_util import AstNode, has_empty_body, with_loc from guppylang.checker.core import Context, Globals from guppylang.checker.expr_checker import check_call, synthesize_call @@ -10,10 +15,9 @@ from guppylang.definition.function import PyFunc, parse_py_func from guppylang.definition.value import CallableDef, CompiledCallableDef from guppylang.error import GuppyError -from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV, VNode from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst -from guppylang.tys.ty import Type, type_to_row +from guppylang.tys.ty import Type @dataclass(frozen=True) @@ -68,9 +72,16 @@ def synthesize_call( node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst)) return node, ty - def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDecl": + def compile_outer( + self, module: _DefinitionBuilder[OpVar] + ) -> "CompiledFunctionDecl": """Adds a Hugr `FuncDecl` node for this function to the Hugr.""" - node = graph.add_declare(self.ty, parent, self.name) + assert isinstance( + module, hf.Module + ), "Functions can only be declared in modules" + module: hf.Module = module + + node = module.declare_function(self.name, self.ty.to_hugr_poly()) return CompiledFunctionDecl( self.id, self.name, @@ -86,30 +97,32 @@ def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDecl": class CompiledFunctionDecl(CheckedFunctionDecl, CompiledCallableDef): """A function declaration with a corresponding Hugr node.""" - hugr_node: VNode + declaration: Node def load_with_args( self, type_args: Inst, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, node: AstNode, - ) -> OutPortV: + ) -> Wire: """Loads the function as a value into a local Hugr dataflow graph.""" - return graph.add_load_function( - self.hugr_node.out_port(0), type_args, dfg.node - ).out_port(0) + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + return dfg.builder.load_function(self.declaration, func_ty, type_args) def compile_call( self, - args: list[OutPortV], + args: list[Wire], type_args: Inst, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, node: AstNode, - ) -> list[OutPortV]: + ) -> list[Wire]: """Compiles a call to the function.""" - call = graph.add_call(self.hugr_node.out_port(0), args, type_args, dfg.node) - return [call.out_port(i) for i in range(len(type_to_row(self.ty.output)))] + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + call = dfg.builder.call( + self.declaration, *args, instantiation=func_ty, type_args=type_args + ) + return list(call) diff --git a/guppylang/definition/extern.py b/guppylang/definition/extern.py index ba5d9e6b..867365f9 100644 --- a/guppylang/definition/extern.py +++ b/guppylang/definition/extern.py @@ -1,14 +1,14 @@ import ast from dataclasses import dataclass, field -from hugr.serialization import ops +from hugr import Node, Wire, val +from hugr.dfg import OpVar, _DefinitionBuilder from guppylang.ast_util import AstNode from guppylang.checker.core import Globals from guppylang.compiler.core import CompiledGlobals, DFContainer from guppylang.definition.common import CompilableDef, ParsableDef from guppylang.definition.value import CompiledValueDef, ValueDef -from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV, VNode from guppylang.tys.parsing import type_from_ast @@ -39,19 +39,22 @@ def parse(self, globals: Globals) -> "ExternDef": class ExternDef(RawExternDef, ValueDef, CompilableDef): """An extern symbol definition.""" - def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledExternDef": + def compile_outer(self, graph: _DefinitionBuilder[OpVar]) -> "CompiledExternDef": """Adds a Hugr constant node for the extern definition to the provided graph.""" + # The `typ` field must be serialized at this point, to ensure that the + # `Extension` is serializable. custom_const = { "symbol": self.symbol, - "typ": self.ty.to_hugr(), + "typ": self.ty.to_hugr().to_serial_root(), "constant": self.constant, } - value = ops.ExtensionValue( - extensions=["prelude"], + value = val.Extension( + name="ConstExternalSymbol", typ=self.ty.to_hugr(), - value=ops.CustomConst(c="ConstExternalSymbol", v=custom_const), + val=custom_const, + extensions=["prelude"], ) - const_node = graph.add_constant(ops.Value(value), self.ty, parent) + const_node = graph.add_const(value) return CompiledExternDef( self.id, self.name, @@ -68,10 +71,8 @@ def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledExternDef": class CompiledExternDef(ExternDef, CompiledValueDef): """An extern symbol definition that has been compiled to a Hugr constant.""" - const_node: VNode + const_node: Node - def load( - self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode - ) -> OutPortV: + def load(self, dfg: DFContainer, globals: CompiledGlobals, node: AstNode) -> Wire: """Loads the extern value into a local Hugr dataflow graph.""" - return graph.add_load_constant(self.const_node.out_port(0)).out_port(0) + return dfg.builder.load(self.const_node) diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index 0abf32b6..d7f8c459 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -5,6 +5,11 @@ from dataclasses import dataclass, field from typing import Any +import hugr.function as hf +import hugr.tys as ht +from hugr import Wire +from hugr.dfg import OpVar, _DefinitionBuilder + from guppylang.ast_util import AstNode, annotate_location, with_loc from guppylang.checker.cfg_checker import CheckedCFG from guppylang.checker.core import Context, Globals, Place, PyScope @@ -19,11 +24,10 @@ from guppylang.definition.common import CheckableDef, CompilableDef, ParsableDef from guppylang.definition.value import CallableDef, CompiledCallableDef from guppylang.error import GuppyError -from guppylang.hugr_builder.hugr import DFContainingVNode, Hugr, Node, OutPortV from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst -from guppylang.tys.ty import FunctionType, Type, type_to_row +from guppylang.tys.ty import FunctionType, Type PyFunc = Callable[..., Any] @@ -35,6 +39,13 @@ class RawFunctionDef(ParsableDef): The raw definition stores exactly what the user has written (i.e. the AST), without any additional checking or parsing. Furthermore, we store the values of the Python variables in scope at the point of definition. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + python_func: The Python function to be defined. + python_scope: The Python scope where the function was defined. """ python_func: PyFunc @@ -61,6 +72,14 @@ class ParsedFunctionDef(CheckableDef, CallableDef): In particular, this means that we have determined a type for the function and are ready to check the function body. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + docstring: The docstring of the function. """ python_scope: PyScope @@ -110,18 +129,29 @@ class CheckedFunctionDef(ParsedFunctionDef, CompilableDef): In particular, this means that we have a constructed and type checked a control-flow graph for the function body. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + docstring: The docstring of the function. + cfg: The type- and linearity-checked CFG for the function body. """ cfg: CheckedCFG[Place] - def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDef": + def compile_outer(self, module: _DefinitionBuilder[OpVar]) -> "CompiledFunctionDef": """Adds a Hugr `FuncDefn` node for this function to the Hugr. Note that we don't compile the function body at this point since we don't have access to the other compiled functions yet. The body is compiled later in `CompiledFunctionDef.compile_inner()`. """ - def_node = graph.add_def(self.ty, parent, self.name) + func_type = self.ty.to_hugr() + func_def = module.define_function(self.name, func_type.input) + func_def.declare_outputs(func_type.output) return CompiledFunctionDef( self.id, self.name, @@ -130,45 +160,58 @@ def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDef": self.python_scope, self.docstring, self.cfg, - def_node, + func_def, ) @dataclass(frozen=True) class CompiledFunctionDef(CheckedFunctionDef, CompiledCallableDef): - """A function definition with a corresponding Hugr node.""" + """A function definition with a corresponding Hugr node. + + Args: + id: The unique definition identifier. + name: The name of the function. + defined_at: The AST node where the function was defined. + ty: The type of the function. + python_scope: The Python scope where the function was defined. + docstring: The docstring of the function. + cfg: The type- and linearity-checked CFG for the function body. + func_def: The Hugr function definition. + """ - hugr_node: DFContainingVNode + func_def: hf.Function def load_with_args( self, type_args: Inst, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, node: AstNode, - ) -> OutPortV: + ) -> Wire: """Loads the function as a value into a local Hugr dataflow graph.""" - return graph.add_load_function( - self.hugr_node.out_port(0), type_args, dfg.node - ).out_port(0) + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + return dfg.builder.load_function(self.func_def, func_ty, type_args) def compile_call( self, - args: list[OutPortV], + args: list[Wire], type_args: Inst, dfg: DFContainer, - graph: Hugr, globals: CompiledGlobals, node: AstNode, - ) -> list[OutPortV]: + ) -> list[Wire]: """Compiles a call to the function.""" - call = graph.add_call(self.hugr_node.out_port(0), args, type_args, dfg.node) - return [call.out_port(i) for i in range(len(type_to_row(self.ty.output)))] + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + call = dfg.builder.call( + self.func_def, *args, instantiation=func_ty, type_args=type_args + ) + return list(call) - def compile_inner(self, graph: Hugr, globals: CompiledGlobals) -> None: + def compile_inner(self, globals: CompiledGlobals) -> None: """Compiles the body of the function.""" - compile_global_func_def(self, self.hugr_node, graph, globals) + compile_global_func_def(self, self.func_def, globals) def parse_py_func(f: PyFunc) -> tuple[ast.FunctionDef, str | None]: diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 98c5b02e..1552b4f1 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -3,9 +3,10 @@ import textwrap from collections.abc import Sequence from dataclasses import dataclass -from functools import cached_property from typing import Any +from hugr import Wire, ops + from guppylang.ast_util import AstNode, annotate_location from guppylang.checker.core import Globals from guppylang.definition.common import ( @@ -23,7 +24,6 @@ from guppylang.definition.parameter import ParamDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError, InternalGuppyError -from guppylang.hugr_builder.hugr import OutPortV from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter, check_all_args @@ -194,15 +194,14 @@ def check_instantiate( check_all_args(self.params, args, self.name, loc) return StructType(args, self) - @cached_property def generated_methods(self) -> list[CustomFunctionDef]: """Auto-generated methods for this struct.""" class ConstructorCompiler(CustomCallCompiler): """Compiler for the `__new__` constructor method of a struct.""" - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - return [self.graph.add_make_tuple(args).out_port(0)] + def compile(self, args: list[Wire]) -> list[Wire]: + return list(self.builder.add(ops.MakeTuple()(*args))) constructor_sig = FunctionType( inputs=[f.ty for f in self.fields], diff --git a/guppylang/definition/ty.py b/guppylang/definition/ty.py index f430fe7d..e50fe4fd 100644 --- a/guppylang/definition/ty.py +++ b/guppylang/definition/ty.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -from hugr.serialization import tys +from hugr import tys from guppylang.ast_util import AstNode from guppylang.definition.common import CompiledDef, Definition diff --git a/guppylang/definition/value.py b/guppylang/definition/value.py index 801763be..a7ac99ee 100644 --- a/guppylang/definition/value.py +++ b/guppylang/definition/value.py @@ -3,10 +3,11 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any +from hugr import Wire + from guppylang.ast_util import AstNode from guppylang.definition.common import CompiledDef, Definition from guppylang.error import GuppyError -from guppylang.hugr_builder.hugr import Hugr, OutPortV from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import FunctionType, Type @@ -30,8 +31,8 @@ class CompiledValueDef(ValueDef, CompiledDef): @abstractmethod def load( - self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode - ) -> OutPortV: + self, dfg: "DFContainer", globals: "CompiledGlobals", node: AstNode + ) -> Wire: """Loads the defined value into a local Hugr dataflow graph.""" @@ -65,13 +66,12 @@ class CompiledCallableDef(CallableDef, CompiledValueDef): @abstractmethod def compile_call( self, - args: list[OutPortV], + args: list[Wire], type_args: Inst, dfg: "DFContainer", - graph: Hugr, globals: "CompiledGlobals", node: AstNode, - ) -> list[OutPortV]: + ) -> list[Wire]: """Compiles a call to the function.""" @abstractmethod @@ -79,17 +79,16 @@ def load_with_args( self, type_args: Inst, dfg: "DFContainer", - graph: Hugr, globals: "CompiledGlobals", node: AstNode, - ) -> OutPortV: + ) -> Wire: """Loads the function into a local Hugr dataflow graph. Requires an instantiation for all function parameters. """ def load( - self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode - ) -> OutPortV: + self, dfg: "DFContainer", globals: "CompiledGlobals", node: AstNode + ) -> Wire: """Loads the defined value into a local Hugr dataflow graph.""" - return self.load_with_args([], dfg, graph, globals, node) + return self.load_with_args([], dfg, globals, node) diff --git a/guppylang/error.py b/guppylang/error.py index b0feebce..d81c528d 100644 --- a/guppylang/error.py +++ b/guppylang/error.py @@ -106,11 +106,15 @@ def ipython_excepthook( yield ipython_shell.set_custom_exc((), None) except NameError: - # Otherwise, override the regular sys.excepthook - old_hook = sys.excepthook - sys.excepthook = hook - yield - sys.excepthook = old_hook + pass + else: + return + + # Otherwise, override the regular sys.excepthook + old_hook = sys.excepthook + sys.excepthook = hook + yield + sys.excepthook = old_hook def format_source_location( diff --git a/guppylang/hugr_builder/hugr.py b/guppylang/hugr_builder/hugr.py deleted file mode 100644 index a9b38e08..00000000 --- a/guppylang/hugr_builder/hugr.py +++ /dev/null @@ -1,865 +0,0 @@ -import itertools -from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Optional - -import networkx as nx # type: ignore[import-untyped] -from hugr import node_port -from hugr.serialization import ops, tys -from hugr.serialization import serial_hugr as raw -from hugr.serialization.ops import OpType - -from guppylang.tys.subst import Inst -from guppylang.tys.ty import ( - FunctionType, - StructType, - SumType, - TupleType, - Type, - TypeRow, - row_to_type, - rows_to_hugr, - type_to_row, -) - -NodeIdx = int -PortOffset = int - - -@dataclass(frozen=True) -class Port(ABC): - """Base class for ports on nodes.""" - - node: "Node" - offset: PortOffset | None - - -class InPort(Port, ABC): - """Base class for a port that incoming wires connect to.""" - - -class OutPort(Port, ABC): - """Base class for a port that outgoing wires come from.""" - - -@dataclass(frozen=True) -class InPortV(InPort): - """A typed value input port.""" - - ty: Type - offset: PortOffset - - -@dataclass(frozen=True) -class OutPortV(OutPort): - """A typed value output port.""" - - ty: Type - offset: PortOffset - - -@dataclass(frozen=True) -class InPortCF(InPort): - """A control-flow input port.""" - - # Control flow inputs are unordered so no port offset is needed - offset: None = field(default=None, init=False) - - -class OutPortCF(OutPort): - """A control-flow output port.""" - - -class DummyOp(OpType): - """A dummy Hugr op that is replaced with a call to a dummy function declaration - during serialisation. - - This is a placeholder for ops that aren't yet available in Hugr. - """ - - root: str # type: ignore[assignment] - - @property - def name(self) -> str: - """The name of the dummy op.""" - return self.root - - -Edge = tuple[OutPort, InPort] - -TypeList = list[Type] - - -@dataclass -class Node(ABC): - """Base class for a node in the graph. - - Has a number of input and output ports and an associated op type. - """ - - idx: NodeIdx - op: ops.OpType - parent: Optional["Node"] - meta_data: dict[str, Any] - - @property - @abstractmethod - def num_in_ports(self) -> int: - """The number of input ports on this node.""" - - @property - @abstractmethod - def num_out_ports(self) -> int: - """The number of output ports on this node.""" - - @abstractmethod - def in_port(self, offset: PortOffset | None) -> InPort: - """Returns the input port at the given offset.""" - - @abstractmethod - def out_port(self, offset: PortOffset | None) -> OutPort: - """Returns the output port at the given offset.""" - - def update_op(self) -> None: # noqa: B027 - """Updates the op type associated with this node with additional information. - - This should be called before serialisation. - """ - - @property - def in_ports(self) -> Iterator[InPort]: - """Returns an iterator over all input ports from left to right.""" - return (self.in_port(i) for i in range(self.num_in_ports)) - - @property - def out_ports(self) -> Iterator[OutPort]: - """Returns an iterator over all output ports from left to right.""" - return (self.out_port(i) for i in range(self.num_out_ports)) - - -@dataclass -class VNode(Node): - """A node with typed value ports.""" - - in_port_types: TypeList - out_port_types: TypeList - - @property - def num_in_ports(self) -> int: - """The number of input ports on this node.""" - return len(self.in_port_types) - - @property - def num_out_ports(self) -> int: - """The number of output ports on this node.""" - return len(self.out_port_types) - - def add_in_port(self, ty: Type) -> InPortV: - """Adds an input port at the end of the node and returns the port.""" - p = InPortV(self, self.num_in_ports, ty) - self.in_port_types.append(ty) - return p - - def add_out_port(self, ty: Type) -> OutPortV: - """Adds an output port at the end of the node and returns the port.""" - p = OutPortV(self, self.num_out_ports, ty) - self.out_port_types.append(ty) - return p - - def in_port(self, offset: PortOffset | None) -> InPortV: - """Returns the input port at the given offset.""" - assert offset is not None - assert offset < self.num_in_ports - assert offset != -1, "Cannot get the port of an order edge" - return InPortV(self, offset, self.in_port_types[offset]) - - def out_port(self, offset: PortOffset | None) -> OutPortV: - """Returns the output port at the given offset.""" - assert offset is not None - assert offset < self.num_out_ports - assert offset != -1, "Cannot get the port of an order edge" - return OutPortV(self, offset, self.out_port_types[offset]) - - @property - def in_ports(self) -> Iterator[InPortV]: - """Returns an iterator over all input ports from left to right.""" - return (self.in_port(i) for i in range(self.num_in_ports)) - - @property - def out_ports(self) -> Iterator[OutPortV]: - """Returns an iterator over all output ports from left to right.""" - return (self.out_port(i) for i in range(self.num_out_ports)) - - def update_op(self) -> None: - """Updates the operation associated with this node with type information. - - Feeds type information from the in- and out-ports to the operation class to - update signature information. This function must be called before serialisation. - """ - # We can't call `to_hugr()` on polymorphic function types, so we have to skip - # ops that have connected `Function` edges. - has_poly_func_edge = isinstance( - self.op.root, ops.FuncDecl | ops.FuncDefn | ops.Call | ops.LoadFunction - ) - if isinstance(self.op.root, ops.BaseOp) and not has_poly_func_edge: - in_types = [t.to_hugr() for t in self.in_port_types] - out_types = [t.to_hugr() for t in self.out_port_types] - self.op.root.insert_port_types(in_types, out_types) - super().update_op() - - -class CFNode(Node): - """A node in a control-flow graph. - - Compared to value nodes, the ports on this node are not typed since they correspond - to control-flow instead of data-flow. - """ - - _num_out_ports: int = 0 - - @property - def num_in_ports(self) -> int: - """The number of input ports on this node.""" - return 0 - - @property - def num_out_ports(self) -> int: - """The number of output ports on this node.""" - return self._num_out_ports - - def add_out_port(self) -> OutPortCF: - """Adds an output port at the end of the node and returns the port.""" - p = OutPortCF(self, self.num_out_ports) - self._num_out_ports += 1 - return p - - def in_port(self, offset: PortOffset | None) -> InPortCF: - assert offset is None - return InPortCF(self) - - def out_port(self, offset: PortOffset | None) -> OutPortCF: - """Returns the output port at the given offset.""" - assert offset is not None - assert offset < self.num_out_ports - return OutPortCF(self, offset) - - def update_op(self) -> None: - super().update_op() - - -class DFContainingNode(Node, ABC): - """Base class for a node whose children form a dataflow graph. - - Compared to a normal node, this node tracks the `Input` and `Output` nodes of its - child DFG which is required to compute the operation signature. - """ - - input_child: Optional["VNode"] = None # Input Node for the child dataflow graph - output_child: Optional["VNode"] = None # Output Node for the child dataflow graph - - def update_op(self) -> None: - """Updates the operation associated with this node with type information. - - Feeds type information from the signature of the contained dataflow graph to - the operation class to. This function must be called before serialisation. - """ - assert self.input_child is not None - assert self.output_child is not None - assert isinstance(self.op.root, ops.BaseOp) - # Input and output node may have extra order edges connected, so we filter - # `None`s here - ins = [ty.to_hugr() for ty in self.input_child.out_port_types] - outs = [ty.to_hugr() for ty in self.output_child.in_port_types] - self.op.root.insert_child_dfg_signature(inputs=ins, outputs=outs) - super().update_op() - - -class DFContainingVNode(VNode, DFContainingNode): - """A value node whose children form a dataflow graph""" - - -class BlockNode(DFContainingNode, CFNode): - """A `Block` node representing a basic block.""" - - -OrderEdge = tuple["Node", "Node"] -ORDER_EDGE_KEY = (-1, -1) - -UNDEFINED: node_port.NodeIdx = -1 - - -class Hugr: - """Hierarchical unified graph representation.""" - - name: str - root: VNode - _graph: nx.MultiDiGraph # TODO: We probably don't need networkx. - _children: dict[NodeIdx, list[Node]] - _default_parent: Node | None - - def __init__(self, name: str | None = None) -> None: - """Creates a new Hugr.""" - self.name = name or "Unnamed" - self._default_parent = None - self._graph = nx.MultiDiGraph() - self._children = {-1: []} - self.root = self.add_node( - op=ops.OpType(ops.Module(parent=UNDEFINED)), - meta_data={"name": name}, - parent=None, - ) - - @contextmanager - def parent(self, parent: Node) -> Iterator[None]: - """Context manager to set a default parent for adding new nodes.""" - old_default = self._default_parent - self._default_parent = parent - yield - self._default_parent = old_default - - def _insert_node(self, node: Node, inputs: list[OutPortV] | None = None) -> None: - """Helper method to insert a node into the graph datastructure.""" - self._graph.add_node(node.idx, data=node) - self._children[node.idx] = [] - self._children[node.parent.idx if node.parent else -1].append(node) - if inputs is not None: - for i, port in enumerate(inputs): - self.add_edge(port, node.in_port(i)) - - def add_node( - self, - op: ops.OpType, - input_types: TypeList | None = None, - output_types: TypeList | None = None, - parent: Node | None = None, - inputs: list[OutPortV] | None = None, - meta_data: dict[str, Any] | None = None, - ) -> VNode: - """Helper method to add a generic value node to the graph.""" - input_types = input_types or [] - output_types = output_types or [] - parent = parent or self._default_parent - node = VNode( - idx=self._graph.number_of_nodes(), - op=op, - parent=parent, - in_port_types=[p.ty for p in inputs] if inputs is not None else input_types, - out_port_types=output_types, - meta_data=meta_data or {}, - ) - self._insert_node(node, inputs) - return node - - def _add_dfg_node( - self, - op: ops.OpType, - input_types: TypeList | None = None, - output_types: TypeList | None = None, - parent: Node | None = None, - inputs: list[OutPortV] | None = None, - meta_data: dict[str, Any] | None = None, - ) -> DFContainingVNode: - """Helper method to add a generic dataflow containing value node to the - graph.""" - input_types = input_types or [] - output_types = output_types or [] - parent = parent or self._default_parent - node = DFContainingVNode( - idx=self._graph.number_of_nodes(), - op=op, - parent=parent, - in_port_types=[p.ty for p in inputs] if inputs is not None else input_types, - out_port_types=output_types, - meta_data=meta_data or {}, - ) - self._insert_node(node, inputs) - return node - - def set_root_name(self, name: str) -> VNode: - """Sets the name of the root node.""" - self.root.meta_data["name"] = name - return self.root - - def add_constant( - self, value: ops.Value, ty: Type, parent: Node | None = None - ) -> VNode: - """Adds a constant node holding a given value to the graph.""" - return self.add_node( - ops.OpType(ops.Const(v=value, parent=UNDEFINED)), [], [ty], parent, None - ) - - def add_input( - self, output_tys: TypeList | None = None, parent: Node | None = None - ) -> VNode: - """Adds an `Input` node to the graph.""" - parent = parent or self._default_parent - node = self.add_node( - ops.OpType(ops.Input(parent=UNDEFINED)), [], output_tys, parent - ) - if isinstance(parent, DFContainingNode): - parent.input_child = node - return node - - def add_input_with_ports( - self, output_tys: Sequence[Type], parent: Node | None = None - ) -> tuple[VNode, list[OutPortV]]: - """Adds an `Input` node to the graph.""" - node = self.add_input(None, parent) - ports = [node.add_out_port(ty) for ty in output_tys] - return node, ports - - def add_output( - self, - inputs: list[OutPortV] | None = None, - input_tys: TypeList | None = None, - parent: Node | None = None, - ) -> VNode: - """Adds an `Output` node to the graph.""" - parent = parent or self._default_parent - node = self.add_node( - ops.OpType(ops.Output(parent=UNDEFINED)), input_tys, [], parent, inputs - ) - if isinstance(parent, DFContainingNode): - parent.output_child = node - return node - - def add_block(self, parent: Node | None, num_successors: int = 0) -> BlockNode: - """Adds a `Block` node to the graph.""" - node = BlockNode( - idx=self._graph.number_of_nodes(), - op=ops.OpType(ops.DataflowBlock(parent=UNDEFINED, sum_rows=[])), - parent=parent, - meta_data={}, - ) - self._insert_node(node) - for _ in range(num_successors): - node.add_out_port() - return node - - def add_exit(self, output_tys: TypeList, parent: Node) -> CFNode: - """Adds an `Exit` node to the graph.""" - outputs = [ty.to_hugr() for ty in output_tys] - node = CFNode( - idx=self._graph.number_of_nodes(), - op=ops.OpType(ops.ExitBlock(cfg_outputs=outputs, parent=UNDEFINED)), - parent=parent, - meta_data={}, - ) - self._insert_node(node) - return node - - def add_dfg(self, parent: Node) -> DFContainingVNode: - """Adds a nested dataflow `DFG` node to the graph.""" - return self._add_dfg_node(ops.OpType(ops.DFG(parent=UNDEFINED)), [], [], parent) - - def add_case(self, parent: Node) -> DFContainingVNode: - """Adds a `Case` node to the graph.""" - return self._add_dfg_node( - ops.OpType(ops.Case(parent=UNDEFINED)), [], [], parent - ) - - def add_cfg(self, parent: Node, inputs: list[OutPortV]) -> VNode: - """Adds a nested control-flow `CFG` node to the graph.""" - return self.add_node( - ops.OpType(ops.CFG(parent=UNDEFINED)), [], [], parent, inputs - ) - - def add_conditional( - self, - cond_input: OutPortV, - inputs: list[OutPortV], - parent: Node | None = None, - ) -> VNode: - """Adds a `Conditional` node to the graph.""" - inputs = [cond_input, *inputs] - return self.add_node( - ops.OpType(ops.Conditional(sum_rows=[], parent=UNDEFINED)), - None, - None, - parent, - inputs, - ) - - def add_tail_loop( - self, inputs: list[OutPortV], parent: Node | None = None - ) -> DFContainingVNode: - """Adds a `TailLoop` node to the graph.""" - return self._add_dfg_node( - ops.OpType(ops.TailLoop(parent=UNDEFINED)), None, None, parent, inputs - ) - - def add_make_tuple( - self, inputs: list[OutPortV], parent: Node | None = None - ) -> VNode: - """Adds a `MakeTuple` node to the graph.""" - ty = TupleType([port.ty for port in inputs]) - return self.add_node( - ops.OpType(ops.MakeTuple(parent=UNDEFINED)), None, [ty], parent, inputs - ) - - def add_unpack_tuple( - self, input_tuple: OutPortV, parent: Node | None = None - ) -> VNode: - """Adds an `UnpackTuple` node to the graph.""" - match input_tuple.ty: - case TupleType(element_types=elems): - tys = list(elems) - case StructType(fields=fields): - tys = [field.ty for field in fields] - case ty: - raise AssertionError(f"Cannot unpack `{ty}`") - return self.add_node( - ops.OpType(ops.UnpackTuple(parent=UNDEFINED)), - None, - tys, - parent, - [input_tuple], - ) - - def add_tag( - self, - variants: Sequence[TypeRow], - tag: int, - inputs: list[OutPortV], - parent: Node | None = None, - ) -> VNode: - """Adds a `Tag` node to the graph.""" - assert all(inp.ty == ty for inp, ty in zip(inputs, variants[tag], strict=True)) - hugr_variants = rows_to_hugr(variants) - out_ty = SumType([row_to_type(row) for row in variants]) - return self.add_node( - ops.OpType(ops.Tag(tag=tag, variants=hugr_variants, parent=UNDEFINED)), - None, - [out_ty], - parent, - inputs, - ) - - def add_load_constant( - self, const_port: OutPortV, parent: Node | None = None - ) -> VNode: - """Adds a `LoadConstant` node to the graph.""" - return self.add_node( - ops.OpType( - ops.LoadConstant(datatype=const_port.ty.to_hugr(), parent=UNDEFINED) - ), - None, - [const_port.ty], - parent, - [const_port], - ) - - def add_load_function( - self, def_port: OutPortV, inst: Inst, parent: Node | None = None - ) -> VNode: - """Adds a `LoadFunction` node to the graph.""" - assert isinstance(def_port.ty, FunctionType) - assert len(def_port.ty.params) == len(inst) - instantiation = def_port.ty.instantiate(inst) - op = ops.LoadFunction( - func_sig=def_port.ty.to_hugr_poly(), - type_args=[arg.to_hugr() for arg in inst], - signature=tys.FunctionType(input=[], output=[instantiation.to_hugr()]), - parent=UNDEFINED, - ) - return self.add_node(ops.OpType(op), None, [instantiation], parent, [def_port]) - - def add_call( - self, - def_port: OutPortV, - args: list[OutPortV], - inst: Inst, - parent: Node | None = None, - ) -> VNode: - """Adds a `Call` node to the graph.""" - assert isinstance(def_port.ty, FunctionType) - instantiation = def_port.ty.instantiate(inst) - op = ops.Call( - func_sig=def_port.ty.to_hugr_poly(), - type_args=[arg.to_hugr() for arg in inst], - instantiation=instantiation.to_hugr().root, - parent=UNDEFINED, - ) - return self.add_node( - ops.OpType(op), - None, - list(type_to_row(instantiation.output)), - parent, - [*args, def_port], - ) - - def add_indirect_call( - self, fun_port: OutPortV, args: list[OutPortV], parent: Node | None = None - ) -> VNode: - """Adds an `IndirectCall` node to the graph.""" - assert isinstance(fun_port.ty, FunctionType) - - return self.add_node( - ops.OpType(ops.CallIndirect(parent=UNDEFINED)), - None, - list(type_to_row(fun_port.ty.output)), - parent, - [fun_port, *args], - ) - - def add_partial( - self, def_port: OutPortV, inputs: list[OutPortV], parent: Node | None = None - ) -> VNode: - """Adds a `Partial` evaluation node to the graph.""" - assert isinstance(def_port.ty, FunctionType) - assert len(def_port.ty.inputs) >= len(inputs) - assert [p.ty.to_hugr() for p in inputs] == [ - ty.to_hugr() for ty in def_port.ty.inputs[: len(inputs)] - ] - new_ty = FunctionType( - def_port.ty.inputs[len(inputs) :], - def_port.ty.output, - def_port.ty.input_names[len(inputs) :] - if def_port.ty.input_names is not None - else None, - ) - return self.add_node( - DummyOp("partial"), None, [new_ty], parent, [*inputs, def_port] - ) - - def add_def( - self, fun_ty: FunctionType, parent: Node | None, name: str - ) -> DFContainingVNode: - """Adds a `FucnDefn` node to the graph.""" - op = ops.FuncDefn(name=name, signature=fun_ty.to_hugr_poly(), parent=UNDEFINED) - return self._add_dfg_node(ops.OpType(op), [], [fun_ty], parent, None) - - def add_declare(self, fun_ty: FunctionType, parent: Node, name: str) -> VNode: - """Adds a `FuncDecl` node to the graph.""" - op = ops.FuncDecl(name=name, signature=fun_ty.to_hugr_poly(), parent=UNDEFINED) - return self.add_node(ops.OpType(op), [], [fun_ty], parent, None) - - def add_edge(self, src_port: OutPort, tgt_port: InPort) -> None: - """Adds an edge between two ports.""" - if isinstance(src_port, OutPortV) or isinstance(tgt_port, InPortV): - assert isinstance(src_port, OutPortV) - assert isinstance(tgt_port, InPortV) - assert src_port.ty == tgt_port.ty - else: - assert isinstance(src_port, OutPortCF) - assert isinstance(tgt_port, InPortCF) - self._graph.add_edge( - src_port.node.idx, tgt_port.node.idx, key=(src_port.offset, tgt_port.offset) - ) - - def add_order_edge(self, src: Node, tgt: Node) -> None: - """Adds a order-edge between two nodes.""" - self._graph.add_edge(src.idx, tgt.idx, key=ORDER_EDGE_KEY) - - def nodes(self) -> Iterator[Node]: - """Returns an iterator over all nodes in the graph.""" - return (n["data"] for n in self._graph.nodes.values()) - - def get_node(self, idx: int) -> Node: - """Returns the node corresponding to given index.""" - return self._graph.nodes[idx]["data"] # type: ignore[no-any-return] - - def children(self, node: Node) -> list[Node]: - """Returns list of a node's immediate children in the hierarchy.""" - return self._children[node.idx] - - def top_level_nodes(self) -> list[Node]: - """Returns list of nodes at the top level of the hierarchy. - - These are nodes that do not have a parent. Usually this will just - be the `Root` node. - """ - return self._children[-1] - - def edges(self) -> Iterator[Edge]: - """Returns an iterator over all edges in the graph.""" - return ( - self._to_edge(*e) - for e in self._graph.edges(keys=True) - if e[2] != ORDER_EDGE_KEY - ) - - def order_edges(self) -> Iterator[OrderEdge]: - """Returns an iterator over all order-edges in the graph.""" - return ( - (self.get_node(e[0]), self.get_node(e[1])) - for e in self._graph.edges(keys=True) - if e[2] == ORDER_EDGE_KEY - ) - - def in_edges(self, port: InPort) -> Iterator[Edge]: - """Returns an iterator over all edges connected to a given in-port.""" - for e in self._graph.in_edges(port.node.idx, keys=True): - if e[2] == ORDER_EDGE_KEY: - continue - src, tgt = self._to_edge(*e) - if tgt.offset == port.offset: - yield src, tgt - - def out_edges(self, port: OutPort) -> Iterator[Edge]: - """Returns an iterator over all edges originating from a given out-port.""" - for e in self._graph.out_edges(port.node.idx, keys=True): - if e[2] == ORDER_EDGE_KEY: - continue - src, tgt = self._to_edge(*e) - if src.offset == port.offset: - yield src, tgt - - def order_successors(self, node: Node) -> Iterator[Node]: - """Returns an iterator over all nodes that this node connects to via an - order edge.""" - for _src, tgt, key in self._graph.out_edges(node.idx, keys=True): - if key == ORDER_EDGE_KEY: - yield tgt - - def order_predecessors(self, node: Node) -> Iterator[Node]: - """Returns an iterator over all nodes that are connected to this node via an - order edge.""" - for src, _tgt, key in self._graph.in_edges(node.idx, keys=True): - if key == ORDER_EDGE_KEY: - yield src - - def _to_edge(self, src: int, tgt: int, key: tuple[int, int]) -> Edge: - src_node = self.get_node(src) - tgt_node = self.get_node(tgt) - return src_node.out_port(key[0]), tgt_node.in_port(key[1]) - - def remove_edge(self, src_port: OutPort, tgt_port: InPort) -> None: - """Removes an edge from the graph.""" - self._graph.remove_edge( - src_port.node.idx, tgt_port.node.idx, key=(src_port.offset, tgt_port.offset) - ) - - def remove_dummy_nodes(self) -> "Hugr": - """Replaces dummy ops with external function calls.""" - if self.root is None: - raise ValueError("Dummy node removal requires a module root node") - used_names: dict[str, int] = {} - for n in list(self.nodes()): - if isinstance(n, VNode) and isinstance(n.op, DummyOp): - name = n.op.name - fun_ty = FunctionType( - list(n.in_port_types), row_to_type(n.out_port_types) - ) - if name in used_names: - used_names[name] += 1 - name = f"{name}${used_names[name]}" - else: - used_names[name] = 0 - decl = self.add_declare(fun_ty, self.root, name) - n.op = ops.OpType( - ops.Call( - func_sig=fun_ty.to_hugr_poly(), - type_args=[], - instantiation=fun_ty.to_hugr().root, - parent=UNDEFINED, - ) - ) - self.add_edge(decl.out_port(0), n.add_in_port(fun_ty)) - return self - - def insert_order_edges(self) -> "Hugr": - """Adds order edges to the source and target inter-graph edges. - - This ensures that the source is executed before the target. This action must be - performed before serialisation. - """ - for src, tgt in list(self.edges()): - # Exclude CF and constant edges - if isinstance(src, OutPortCF) or isinstance( - src.node.op.root, ops.FuncDecl | ops.FuncDefn | ops.Const - ): - continue - - if src.node.parent != tgt.node.parent: - # Walk up the hierarchy from the tgt until we hit a node at the same - # level as src - node = tgt.node - while node.parent != src.node.parent: - if node.parent is None: - raise ValueError("Invalid non-local edge!") - node = node.parent - # Add order edge to make sure that the src is executed first - self.add_order_edge(src.node, node) - return self - - def to_raw(self) -> raw.SerialHugr: - """Returns the raw representation of this HUGR for serialisation.""" - if self.root is None: - raise ValueError("Serial Hugr requires a root node") - - self.remove_dummy_nodes() - self.insert_order_edges() - # Hugr requires that Input/Output nodes are the first/second children in a DFG. - # Furthermore, exit nodes must be the second children of CFGs. We're going to - # satisfy this trivially by first serialising all inputs, outputs, entry and - # exit nodes - input_nodes: list[Node] = [] - output_nodes: list[Node] = [] - entry_nodes: list[Node] = [] - exit_nodes: list[Node] = [] - remaining_nodes: list[Node] = [] - indices = itertools.count() - raw_index: dict[int, node_port.NodeIdx] = {} - all_nodes = self.nodes() - for n in all_nodes: - if n is self.root: - continue - match n.op.root: - case ops.Input(): - input_nodes.append(n) - case ops.Output(): - output_nodes.append(n) - # We can detect entry BBs by looking for BBs without incoming edges - # since Guppy will never generate an edge pointing back to the entry. - # Also, Guppy errors on unreachable code, so we will never generate - # interior BBs without incoming edges. Hence, there are also no false - # positives. - case ops.DataflowBlock() if next( - self.in_edges(n.in_port(None)), None - ) is None: - entry_nodes.append(n) - case ops.ExitBlock(): - exit_nodes.append(n) - case _: - remaining_nodes.append(n) - for n in itertools.chain( - iter([self.root]), - iter(entry_nodes), - iter(exit_nodes), - iter(input_nodes), - iter(output_nodes), - iter(remaining_nodes), - ): - raw_index[n.idx] = next(indices) - - nodes: list[ops.OpType] = [ - ops.OpType(ops.Module(parent=UNDEFINED)) - ] * self._graph.number_of_nodes() - for n in self.nodes(): - idx = raw_index[n.idx] - # Nodes without parent have themselves as parent in the serialised format - parent = n.parent or n - n.update_op() - n.op.root.parent = raw_index[parent.idx] - nodes[idx] = n.op - - edges: list[raw.Edge] = [] - for src, tgt in self.edges(): - edges.append( - ( - (raw_index[src.node.idx], src.offset), - (raw_index[tgt.node.idx], tgt.offset), - ) - ) - - for src, tgt in self.order_edges(): - edges.append(((raw_index[src.idx], None), (raw_index[tgt.idx], None))) - - return raw.SerialHugr(nodes=nodes, edges=edges) - - def serialize(self) -> str: - """Serialize this Hugr in JSON format.""" - return self.to_raw().to_json() diff --git a/guppylang/hugr_builder/visualise.py b/guppylang/hugr_builder/visualise.py deleted file mode 100644 index 988f91be..00000000 --- a/guppylang/hugr_builder/visualise.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Visualise HUGR using graphviz.""" - -import ast -from collections.abc import Iterable -from typing import TYPE_CHECKING - -import graphviz as gv # type: ignore[import-untyped] - -from guppylang.cfg.analysis import ( - DefAssignmentDomain, - LivenessDomain, - MaybeAssignmentDomain, -) -from guppylang.cfg.bb import BB, VId -from guppylang.hugr_builder.hugr import DummyOp, Hugr, InPort, Node, OutPort, OutPortV - -if TYPE_CHECKING: - from guppylang.cfg.cfg import CFG - -# old palettte: https://colorhunt.co/palette/343a407952b3ffc107e1e8eb -# _COLOURS = { -# "background": "white", -# "node": "#7952B3", -# "edge": "#FFC107", -# "dark": "#343A40", -# "const": "#7c55b4", -# "discard": "#ff8888", -# "node_border": "#9d80c7", -# "port_border": "#ffd966", -# } - -# ZX colours -# _COLOURS = { -# "background": "white", -# "node": "#629DD1", -# "edge": "#297FD5", -# "dark": "#112D4E", -# "const": "#a1eea1", -# "discard": "#ff8888", -# "node_border": "#D8F8D8", -# "port_border": "#E8A5A5", -# } - -# Conference talk colours -_COLOURS = { - "background": "white", - "node": "#ACCBF9", - "edge": "#1CADE4", - "dark": "black", - "const": "#77CEEF", - "discard": "#ff8888", - "node_border": "white", - "port_border": "#1CADE4", -} - - -_FONTFACE = "monospace" - -_HTML_LABEL_TEMPLATE = """ - - {inputs_row} - - - - {outputs_row} -
- - -
{node_label}{node_data}
-
-""" - - -def _format_html_label(**kwargs: str) -> str: - _HTML_LABEL_DEFAULTS = { - "label_color": _COLOURS["dark"], - "node_back_color": _COLOURS["node"], - "inputs_row": "", - "outputs_row": "", - "border_colour": _COLOURS["port_border"], - "border_width": "1", - "fontface": _FONTFACE, - "fontsize": 11.0, - } - return _HTML_LABEL_TEMPLATE.format(**{**_HTML_LABEL_DEFAULTS, **kwargs}) - - -_HTML_PORTS_ROW_TEMPLATE = """ - - - - - {port_cells} - -
- - -""" - -_HTML_PORT_TEMPLATE = ( - '' - '{port}' -) - -_INPUT_PREFIX = "in." -_OUTPUT_PREFIX = "out." - - -def _html_ports(ports: Iterable[str], id_prefix: str) -> str: - return _HTML_PORTS_ROW_TEMPLATE.format( - port_cells="".join( - _HTML_PORT_TEMPLATE.format( - port=port, - # differentiate input and output node identifiers - # with a prefix - port_id=id_prefix + port, - back_colour=_COLOURS["background"], - font_colour=_COLOURS["dark"], - border_width="1", - border_colour=_COLOURS["port_border"], - fontface=_FONTFACE, - ) - for port in ports - ) - ) - - -def _in_port_name(p: InPort) -> str: - return ( - f"{p.node.idx}:{_INPUT_PREFIX}{p.offset}" - if p.offset is not None - else str(p.node.idx) - ) - - -def _out_port_name(p: OutPort) -> str: - return ( - f"{p.node.idx}:{_OUTPUT_PREFIX}{p.offset}" - if p.offset is not None - else str(p.node.idx) - ) - - -def _in_order_name(n: Node) -> str: - return f"{n.idx}:{_INPUT_PREFIX}None" - - -def _out_order_name(n: Node) -> str: - return f"{n.idx}:{_OUTPUT_PREFIX}None" - - -def viz_node(node: Node, hugr: Hugr, graph: gv.Digraph) -> None: - in_ports = [str(i) for i in range(node.num_in_ports)] - out_ports = [str(i) for i in range(node.num_out_ports)] - if len(node.meta_data) > 0: - data = "

" + "
".join( - f"{key}: {value}" for key, value in node.meta_data.items() - ) - else: - data = "" - if len(hugr.children(node)) > 0: - with graph.subgraph(name=f"cluster{node.idx}") as sub: - for child in hugr.children(node): - viz_node(child, hugr, sub) - html_label = _format_html_label( - node_back_color=_COLOURS["edge"], - node_label=node.op.root.display_name(), - node_data=data, - border_colour=_COLOURS["port_border"], - inputs_row=_html_ports(in_ports, _INPUT_PREFIX) - if len(in_ports) > 0 - else "", - outputs_row=_html_ports(out_ports, _OUTPUT_PREFIX) - if len(out_ports) > 0 - else "", - ) - sub.node(f"{node.idx}", shape="plain", label=f"<{html_label}>") - sub.attr(label="", margin="10", color=_COLOURS["edge"]) - else: - html_label = _format_html_label( - node_back_color=_COLOURS["node"], - node_label=node.op.name - if isinstance(node.op, DummyOp) - else node.op.root.display_name(), - node_data=data, - inputs_row=_html_ports(in_ports, _INPUT_PREFIX) - if len(in_ports) > 0 - else "", - outputs_row=_html_ports(out_ports, _OUTPUT_PREFIX) - if len(out_ports) > 0 - else "", - border_colour=_COLOURS["background"], - ) - graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain") - - -def hugr_to_graphviz(hugr: Hugr) -> gv.Digraph: - graph_atrr = { - "rankdir": "", - "ranksep": "0.1", - "nodesep": "0.15", - "margin": "0", - "bgcolor": _COLOURS["background"], - } - graph = gv.Digraph(hugr.name, strict=False) - graph.attr(**graph_atrr) - for node in hugr.top_level_nodes(): - viz_node(node, hugr, graph) - edge_attr = { - "penwidth": "1.5", - "arrowhead": "none", - "arrowsize": "1.0", - "fontname": _FONTFACE, - "fontsize": "9", - "fontcolor": "black", - } - for src_port, tgt_port in hugr.edges(): - graph.edge( - _out_port_name(src_port), - _in_port_name(tgt_port), - label=str(src_port.ty) if isinstance(src_port, OutPortV) else "", - color=_COLOURS["edge"] - if isinstance(src_port, OutPortV) - else _COLOURS["dark"], - **edge_attr, - ) - for src, tgt in hugr.order_edges(): - graph.edge( - _out_order_name(src), - _in_order_name(tgt), - label="", - color=_COLOURS["dark"], - **edge_attr, - ) - return graph - - -def render_hugr(hugr: Hugr, filename: str, format_st: str = "svg") -> None: - gv_graph = hugr_to_graphviz(hugr) - gv_graph.render(filename, format=format_st) - - -def commas(*args: str) -> str: - return ", ".join(args) - - -def cfg_to_graphviz( - cfg: "CFG", - live_before: dict[BB, LivenessDomain[VId]], - ass_before: dict[BB, DefAssignmentDomain[VId]], - maybe_ass_before: dict[BB, MaybeAssignmentDomain[VId]], -) -> gv.Digraph: - graph = gv.Digraph("CFG", strict=False) - for bb in cfg.bbs: - label = f""" -assigned: {commas(*(str(x) for x in bb.vars.assigned))} -used: {commas(*(str(x) for x in bb.vars.used))} -maybe_ass_before: {commas(*(str(x) for x in maybe_ass_before[bb]))} -ass_before: {commas(*(str(x) for x in ass_before[bb]))} -live_before: {commas(*(str(x) for x in live_before[bb]))} --------- -""" + "\n".join(ast.unparse(s) for s in bb.statements) - if bb.branch_pred is not None: - label += f"\n{ast.unparse(bb.branch_pred)} ?" - graph.node(str(bb.idx), label, shape="rect") - for succ in bb.successors: - graph.edge(str(bb.idx), str(succ.idx)) - return graph - - -def render_cfg( - cfg: "CFG", - live_before: dict[BB, LivenessDomain[VId]], - ass_before: dict[BB, DefAssignmentDomain[VId]], - maybe_ass_before: dict[BB, MaybeAssignmentDomain[VId]], - filename: str, - format_st: str = "svg", -) -> None: - gv_graph = cfg_to_graphviz(cfg, live_before, ass_before, maybe_ass_before) - gv_graph.render(filename, format=format_st) diff --git a/guppylang/module.py b/guppylang/module.py index 934c2b99..de9fb387 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -5,6 +5,9 @@ from types import ModuleType from typing import Any, Union +from hugr import Hugr, ops +from hugr.function import Module + from guppylang.checker.core import Globals, PyScope from guppylang.compiler.core import CompiledGlobals from guppylang.definition.common import ( @@ -22,7 +25,6 @@ from guppylang.definition.struct import CheckedStructDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError, pretty_errors -from guppylang.hugr_builder.hugr import Hugr PyFunc = Callable[..., Any] PyFuncDefOrDecl = tuple[bool, PyFunc] @@ -38,7 +40,7 @@ class GuppyModule: # If the hugr has already been compiled, keeps a reference that can be returned # from `compile`. - _compiled_hugr: Hugr | None + _compiled_hugr: Hugr[ops.Module] | None # Map of raw definitions in this module _raw_defs: dict[DefId, RawDef] @@ -169,12 +171,16 @@ def _check_defs( } @pretty_errors - def compile(self) -> Hugr: + def compile(self) -> Hugr[ops.Module]: """Compiles the module and returns the final Hugr.""" if self.compiled: assert self._compiled_hugr is not None, "Module is compiled but has no Hugr" return self._compiled_hugr + # Prepare Hugr for this module + graph = Module() + graph.metadata["name"] = self.name + # Type definitions need to be checked first so that we can use them when parsing # function signatures etc. type_defs = self._check_defs( @@ -187,7 +193,7 @@ def compile(self) -> Hugr: for defn in type_defs.values(): if isinstance(defn, CheckedStructDef): self._globals.impls.setdefault(defn.id, {}) - for method_def in defn.generated_methods: + for method_def in defn.generated_methods(): generated[method_def.id] = method_def self._globals.impls[defn.id][method_def.name] = method_def.id @@ -197,16 +203,10 @@ def compile(self) -> Hugr: ) self._globals = self._globals.update_defs(other_defs) - # Prepare Hugr for this module - graph = Hugr(self.name) - module_node = graph.set_root_name(self.name) - # Compile definitions to Hugr self._compiled_globals = { defn.id: ( - defn.compile_outer(graph, module_node) - if isinstance(defn, CompilableDef) - else defn + defn.compile_outer(graph) if isinstance(defn, CompilableDef) else defn ) for defn in itertools.chain(type_defs.values(), other_defs.values()) } @@ -215,11 +215,12 @@ def compile(self) -> Hugr: # Finally, compile the definition contents to Hugr. For example, this compiles # the bodies of functions. for defn in self._compiled_globals.values(): - defn.compile_inner(graph, all_compiled_globals) + defn.compile_inner(all_compiled_globals) + hugr = graph.hugr self._compiled = True - self._compiled_hugr = graph - return graph + self._compiled_hugr = hugr + return hugr def contains(self, name: str) -> bool: """Returns 'True' if the module contains an object with the given name.""" diff --git a/guppylang/hugr_builder/__init__.py b/guppylang/prelude/_internal/__init__.py similarity index 100% rename from guppylang/hugr_builder/__init__.py rename to guppylang/prelude/_internal/__init__.py diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal/checker.py similarity index 50% rename from guppylang/prelude/_internal.py rename to guppylang/prelude/_internal/checker.py index 70558bf0..f6c59995 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal/checker.py @@ -1,9 +1,6 @@ import ast -from hugr.serialization import ops, tys -from pydantic import BaseModel - -from guppylang.ast_util import AstNode, get_type, with_loc +from guppylang.ast_util import AstNode, with_loc from guppylang.checker.core import Context from guppylang.checker.expr_checker import ( ExprSynthesizer, @@ -14,22 +11,14 @@ ) from guppylang.definition.custom import ( CustomCallChecker, - CustomCallCompiler, CustomFunctionDef, DefaultCallChecker, ) from guppylang.definition.value import CallableDef from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall, ResultExpr from guppylang.tys.arg import ConstArg, TypeArg -from guppylang.tys.builtin import ( - bool_type, - int_type, - is_array_type, - is_bool_type, - list_type, -) +from guppylang.tys.builtin import bool_type, int_type, is_array_type, is_bool_type from guppylang.tys.const import Const, ConstValue from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( @@ -37,101 +26,10 @@ NoneType, NumericType, Type, - type_to_row, unify, ) -class ConstInt(BaseModel): - """Hugr representation of signed and unsigned integers in the arithmetic extension. - - Hugr always uses a u64 for the value. The interpretation is: - - as an unsigned integer, (value mod `2^N`); - - as a signed integer, (value mod `2^(N-1) - 2^(N-1)*a`) - where `N = 2^log_width` and `a` is the (N-1)th bit of `x` (counting from 0 = least - significant bit). - """ - - log_width: int - value: int - - -class ConstF64(BaseModel): - """Hugr representation of float values.""" - - value: float - - -def bool_value(b: bool) -> ops.Value: - """Returns the Hugr representation of a boolean value.""" - return ops.Value( - ops.SumValue(tag=int(b), typ=tys.SumType(tys.UnitSum(size=2)), vs=[]) - ) - - -def int_value(i: int) -> ops.Value: - """Returns the Hugr representation of an integer value.""" - return ops.Value( - ops.ExtensionValue( - extensions=["arithmetic.int.types"], - typ=NumericType(NumericType.Kind.Int).to_hugr(), - value=ops.CustomConst( - c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i) - ), - ) - ) - - -def float_value(f: float) -> ops.Value: - """Returns the Hugr representation of a float value.""" - return ops.Value( - ops.ExtensionValue( - extensions=["arithmetic.float.types"], - typ=NumericType(NumericType.Kind.Float).to_hugr(), - value=ops.CustomConst(c="ConstF64", v=ConstF64(value=f)), - ) - ) - - -def list_value(v: list[ops.Value], ty: Type) -> ops.Value: - """Returns the Hugr representation of a list value.""" - return ops.Value( - ops.ExtensionValue( - extensions=["Collections"], - typ=list_type(ty).to_hugr(), - value=ops.CustomConst(c="ListValue", v=(v, ty.to_hugr())), - ) - ) - - -def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType: - """Utility method to create Hugr logic ops.""" - return ops.OpType( - ops.CustomOp(extension="logic", name=op_name, args=args or [], parent=UNDEFINED) - ) - - -def int_op( - op_name: str, - ext: str = "arithmetic.int", - args: list[tys.TypeArg] | None = None, - num_params: int = 1, -) -> ops.OpType: - """Utility method to create Hugr integer arithmetic ops.""" - if args is None: - args = num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))] - return ops.OpType( - ops.CustomOp(extension=ext, name=op_name, args=args, parent=UNDEFINED) - ) - - -def float_op(op_name: str, ext: str = "arithmetic.float") -> ops.OpType: - """Utility method to create Hugr integer arithmetic ops.""" - return ops.OpType( - ops.CustomOp(extension=ext, name=op_name, args=[], parent=UNDEFINED) - ) - - class CoercingChecker(DefaultCallChecker): """Function call type checker that automatically coerces arguments to float.""" @@ -160,13 +58,13 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None: def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: expr, subst = self.base_checker.check(args, ty) if isinstance(expr, GlobalCall): - expr.args = list(reversed(args)) + expr.args = list(reversed(expr.args)) return expr, subst def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: expr, ty = self.base_checker.synthesize(args) if isinstance(expr, GlobalCall): - expr.args = list(reversed(args)) + expr.args = list(reversed(expr.args)) return expr, ty @@ -320,143 +218,3 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: @staticmethod def _is_numeric_or_bool_type(ty: Type) -> bool: return isinstance(ty, NumericType) or is_bool_type(ty) - - -class NatTruedivCompiler(CustomCallCompiler): - """Compiler for the `nat.__truediv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float, Nat - - # Compile `truediv` using float arithmetic - [left, right] = args - [left] = Nat.__float__.compile_call( - [left], [], self.dfg, self.graph, self.globals, self.node - ) - [right] = Nat.__float__.compile_call( - [right], [], self.dfg, self.graph, self.globals, self.node - ) - [out] = Float.__truediv__.compile_call( - [left, right], [], self.dfg, self.graph, self.globals, self.node - ) - return [out] - - -class IntTruedivCompiler(CustomCallCompiler): - """Compiler for the `int.__truediv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float, Int - - # Compile `truediv` using float arithmetic - [left, right] = args - [left] = Int.__float__.compile_call( - [left], [], self.dfg, self.graph, self.globals, self.node - ) - [right] = Int.__float__.compile_call( - [right], [], self.dfg, self.graph, self.globals, self.node - ) - [out] = Float.__truediv__.compile_call( - [left, right], [], self.dfg, self.graph, self.globals, self.node - ) - return [out] - - -class FloatBoolCompiler(CustomCallCompiler): - """Compiler for the `float.__bool__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: bool(x) = (x != 0.0) - zero_const = self.graph.add_constant( - float_value(0.0), get_type(self.node), self.dfg.node - ) - zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) - [out] = Float.__ne__.compile_call( - [args[0], zero.out_port(0)], - [], - self.dfg, - self.graph, - self.globals, - self.node, - ) - return [out] - - -class FloatFloordivCompiler(CustomCallCompiler): - """Compiler for the `float.__floordiv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: floordiv(x, y) = floor(truediv(x, y)) - [div] = Float.__truediv__.compile_call( - args, [], self.dfg, self.graph, self.globals, self.node - ) - [floor] = Float.__floor__.compile_call( - [div], [], self.dfg, self.graph, self.globals, self.node - ) - return [floor] - - -class FloatModCompiler(CustomCallCompiler): - """Compiler for the `float.__mod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: mod(x, y) = x - (x // y) * y - [div] = Float.__floordiv__.compile_call( - args, [], self.dfg, self.graph, self.globals, self.node - ) - [mul] = Float.__mul__.compile_call( - [div, args[1]], [], self.dfg, self.graph, self.globals, self.node - ) - [sub] = Float.__sub__.compile_call( - [args[0], mul], [], self.dfg, self.graph, self.globals, self.node - ) - return [sub] - - -class FloatDivmodCompiler(CustomCallCompiler): - """Compiler for the `__divmod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: divmod(x, y) = (div(x, y), mod(x, y)) - [div] = Float.__truediv__.compile_call( - args, [], self.dfg, self.graph, self.globals, self.node - ) - [mod] = Float.__mod__.compile_call( - args, [], self.dfg, self.graph, self.globals, self.node - ) - return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] - - -class MeasureCompiler(CustomCallCompiler): - """Compiler for the `measure` function.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .quantum import quantum_op - - [qubit] = args - measure = self.graph.add_node(quantum_op("Measure"), inputs=args) - self.graph.add_node( - quantum_op("QFree"), inputs=[measure.add_out_port(qubit.ty)] - ) - return [measure.add_out_port(bool_type())] - - -class QAllocCompiler(CustomCallCompiler): - """Compiler for the `qubit` function.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .quantum import quantum_op - - [qubit_ty] = type_to_row(get_type(self.node)) - qalloc = self.graph.add_node(quantum_op("QAlloc"), inputs=args) - qalloc_result = qalloc.add_out_port(qubit_ty) - reset = self.graph.add_node(quantum_op("Reset"), inputs=[qalloc_result]) - return [reset.add_out_port(qubit_ty)] diff --git a/guppylang/prelude/_internal/compiler.py b/guppylang/prelude/_internal/compiler.py new file mode 100644 index 00000000..798031e0 --- /dev/null +++ b/guppylang/prelude/_internal/compiler.py @@ -0,0 +1,218 @@ +import hugr +from hugr import Wire, ops +from hugr import tys as ht +from hugr.std.float import FLOAT_T + +from guppylang.definition.custom import ( + CustomCallCompiler, +) +from guppylang.tys.ty import NumericType + +# Note: Hugr's INT_T is 64bits, but guppy defaults to 32bits +INT_T = NumericType(NumericType.Kind.Int).to_hugr() + + +class NatTruedivCompiler(CustomCallCompiler): + """Compiler for the `nat.__truediv__` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.builtins import Float, Nat + + # Compile `truediv` using float arithmetic + [left, right] = args + [left] = Nat.__float__.compile_call( + [left], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([INT_T], [FLOAT_T]), + ) + [right] = Nat.__float__.compile_call( + [right], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([INT_T], [FLOAT_T]), + ) + [out] = Float.__truediv__.compile_call( + [left, right], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), + ) + return [out] + + +class IntTruedivCompiler(CustomCallCompiler): + """Compiler for the `int.__truediv__` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.builtins import Float, Int + + # Compile `truediv` using float arithmetic + [left, right] = args + [left] = Int.__float__.compile_call( + [left], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([INT_T], [FLOAT_T]), + ) + [right] = Int.__float__.compile_call( + [right], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([INT_T], [FLOAT_T]), + ) + [out] = Float.__truediv__.compile_call( + [left, right], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), + ) + return [out] + + +class FloatBoolCompiler(CustomCallCompiler): + """Compiler for the `float.__bool__` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.builtins import Float + + # We have: bool(x) = (x != 0.0) + zero = self.builder.load(hugr.std.float.FloatVal(0.0)) + [out] = Float.__ne__.compile_call( + [args[0], zero], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [ht.Bool]), + ) + return [out] + + +class FloatFloordivCompiler(CustomCallCompiler): + """Compiler for the `float.__floordiv__` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.builtins import Float + + # We have: floordiv(x, y) = floor(truediv(x, y)) + [div] = Float.__truediv__.compile_call( + args, + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), + ) + [floor] = Float.__floor__.compile_call( + [div], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T], [FLOAT_T]), + ) + return [floor] + + +class FloatModCompiler(CustomCallCompiler): + """Compiler for the `float.__mod__` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.builtins import Float + + # We have: mod(x, y) = x - (x // y) * y + [div] = Float.__floordiv__.compile_call( + args, + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]), + ) + [mul] = Float.__mul__.compile_call( + [div, args[1]], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), + ) + [sub] = Float.__sub__.compile_call( + [args[0], mul], + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), + ) + return [sub] + + +class FloatDivmodCompiler(CustomCallCompiler): + """Compiler for the `__divmod__` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.builtins import Float + + # We have: divmod(x, y) = (div(x, y), mod(x, y)) + [div] = Float.__truediv__.compile_call( + args, + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), + ) + [mod] = Float.__mod__.compile_call( + args, + [], + self.dfg, + self.globals, + self.node, + ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]), + ) + return list(self.builder.add(ops.MakeTuple()(div, mod))) + + +class MeasureCompiler(CustomCallCompiler): + """Compiler for the `measure` function.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.quantum import quantum_op + + [q] = args + [q, bit] = self.builder.add_op( + quantum_op("Measure")(ht.FunctionType([ht.Qubit], [ht.Qubit, ht.Bool]), []), + q, + ) + self.builder.add_op(quantum_op("QFree")(ht.FunctionType([ht.Qubit], []), []), q) + return [bit] + + +class QAllocCompiler(CustomCallCompiler): + """Compiler for the `qubit` function.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + from guppylang.prelude.quantum import quantum_op + + assert not args, "qubit() does not take any arguments" + q = self.builder.add_op( + quantum_op("QAlloc")(ht.FunctionType([], [ht.Qubit]), []) + ) + q = self.builder.add_op( + quantum_op("Reset")(ht.FunctionType([ht.Qubit], [ht.Qubit]), []), q + ) + return [q] diff --git a/guppylang/prelude/_internal/util.py b/guppylang/prelude/_internal/util.py new file mode 100644 index 00000000..92ce6634 --- /dev/null +++ b/guppylang/prelude/_internal/util.py @@ -0,0 +1,220 @@ +"""Utilities for defining builtin functions. + +Note: These custom definitions will be replaced with direct extension operation +definitions from the hugr library. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from hugr import ops +from hugr import tys as ht +from hugr import val as hv + +from guppylang.tys.builtin import list_type +from guppylang.tys.subst import Inst +from guppylang.tys.ty import NumericType, Type + +if TYPE_CHECKING: + from guppylang.tys.arg import Argument + + +@dataclass +class ListVal(hv.ExtensionValue): + """Custom value for a floating point number.""" + + v: list[hv.Value] + ty: ht.Type + + def __init__(self, v: list[hv.Value], elem_ty: Type) -> None: + self.v = v + self.ty = list_type(elem_ty).to_hugr() + + def to_value(self) -> hv.Extension: + # The value list must be serialized at this point, otherwise the + # `Extension` will not be serializable. + vs = [v.to_serial_root() for v in self.v] + return hv.Extension( + name="ListValue", typ=self.ty, val=vs, extensions=["Collections"] + ) + + +def int_arg(n: int = NumericType.INT_WIDTH) -> ht.TypeArg: + """A bounded int type argument.""" + return ht.BoundedNatArg(n=n) + + +def type_arg(idx: int = 0) -> ht.TypeArg: + """A generic type argument.""" + return ht.VariableArg(idx=idx, param=ht.TypeTypeParam(bound=ht.TypeBound.Copyable)) + + +def ltype_arg(idx: int = 0) -> ht.TypeArg: + """A generic linear type argument.""" + return ht.VariableArg(idx=idx, param=ht.TypeTypeParam(bound=ht.TypeBound.Any)) + + +def make_concrete_arg( + arg: ht.TypeArg, + inst: Inst, + variable_remap: dict[int, int] | None = None, +) -> ht.TypeArg: + """Makes a concrete hugr type argument using a guppy instantiation. + + Args: + arg: The hugr type argument to make concrete, containing variable arguments. + inst: The guppy instantiation of the type arguments. + variable_remap: A mapping from the hugr param variable indices to + de Bruijn indices in the guppy type. Defaults to identity. + """ + remap = variable_remap or {} + + if isinstance(arg, ht.VariableArg) and remap.get(arg.idx, arg.idx) < len(inst): + concrete_arg: Argument = inst[remap.get(arg.idx, arg.idx)] + return concrete_arg.to_hugr() + return arg + + +def custom_op( + name: str, + args: list[ht.TypeArg], + *, + ext: str = "guppy.unsupported", + variable_remap: dict[int, int] | None = None, +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Custom hugr operation + + Args: + op_name: The name of the operation. + args: The type arguments of the operation. + ext: The extension of the operation. Defaults to a placeholder extension. + variable_remap: A mapping from the hugr param variable indices to + de Bruijn indices in the guppy type. Defaults to identity. + + Returns: + A function that takes an instantiation of the type arguments as well as + the inferred input and output types and returns a concrete HUGR op. + """ + + def op(ty: ht.FunctionType, inst: Inst) -> ops.DataflowOp: + concrete_args = [make_concrete_arg(arg, inst, variable_remap) for arg in args] + return ops.Custom(extension=ext, signature=ty, name=name, args=concrete_args) + + return op + + +def list_op( + op_name: str, + ext: str = "guppy.unsupported", +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Utility method to create Hugr list operations. + + These ops have exactly one type argument, used to instantiate the list type. + If a the input or output types contain some variable type with index 0, it + is replaced with the type argument when instantiating the op. + + Args: + op_name: The name of the operation. + ext: The extension of the operation. + + Returns: + A function that takes an instantiation of the type arguments and returns + a concrete HUGR op. + """ + return custom_op(op_name, args=[type_arg(0)], ext=ext, variable_remap=None) + + +def linst_op( + op_name: str, + ext: str = "guppy.unsupported", +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Utility method to create linear Hugr list operations. + + These ops have exactly one type argument, used to instantiate the list type. + If a the input or output types contain some variable type with index 0, it + is replaced with the type argument when instantiating the op. + + Args: + op_name: The name of the operation. + ext: The extension of the operation. + + Returns: + A function that takes an instantiation of the type arguments and returns + a concrete HUGR op. + """ + return custom_op(op_name, args=[ltype_arg(0)], ext=ext, variable_remap=None) + + +def float_op( + op_name: str, + ext: str = "arithmetic.float", +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Utility method to create Hugr float arithmetic ops. + + Args: + op_name: The name of the operation. + ext: The extension of the operation. + + Returns: + A function that takes an instantiation of the type arguments and returns + a concrete HUGR op. + """ + return custom_op(op_name, args=[], ext=ext, variable_remap=None) + + +def int_op( + op_name: str, + ext: str = "arithmetic.int", + n_vars: int = 1, +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Utility method to create Hugr integer arithmetic ops. + + Args: + op_name: The name of the operation. + ext: The extension of the operation. + n_vars: The number of type arguments. Defaults to 1. + + Returns: + A function that takes an instantiation of the type arguments and returns + a concrete HUGR op. + """ + # Ideally we'd be able to derive the arguments from the input/output types, + # but the amount of variables does not correlate with the signature for the + # integer ops in hugr :/ + # https://github.com/CQCL/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116 + # + # For now, we just instantiate every type argument to a 64-bit integer. + args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)] + + return custom_op( + op_name, + args=args, + ext=ext, + variable_remap=None, + ) + + +def logic_op( + op_name: str, parametric_size: bool = True, ext: str = "logic" +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Utility method to create Hugr logic ops. + + If `parametric_size` is True, the generated operations has a single argument + encoding the number of boolean inputs to the operation. + + args: + op_name: The name of the operation. + parametric_size: Whether the input count is a parameter to the operation. + ext: The extension of the operation. + + Returns: + A function that takes an instantiation of the type arguments and returns + a concrete HUGR op. + """ + + def op(ty: ht.FunctionType, inst: Inst) -> ops.DataflowOp: + args = [int_arg(len(ty.input))] if parametric_size else [] + return ops.Custom(extension=ext, signature=ty, name=op_name, args=args) + + return op diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index a11a75a1..0be6aef4 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -4,31 +4,37 @@ from typing import Any, Generic, TypeVar -from hugr.serialization import tys - from guppylang.decorator import guppy from guppylang.definition.custom import DefaultCallChecker, NoopCompiler from guppylang.error import GuppyError -from guppylang.hugr_builder.hugr import DummyOp from guppylang.module import GuppyModule -from guppylang.prelude._internal import ( +from guppylang.prelude._internal.checker import ( ArrayLenChecker, CallableChecker, CoercingChecker, DunderChecker, FailingChecker, + ResultChecker, + ReversingChecker, + UnsupportedChecker, +) +from guppylang.prelude._internal.compiler import ( FloatBoolCompiler, FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, IntTruedivCompiler, NatTruedivCompiler, - ResultChecker, - ReversingChecker, - UnsupportedChecker, +) +from guppylang.prelude._internal.util import ( + custom_op, float_op, + int_arg, int_op, + linst_op, + list_op, logic_op, + type_arg, ) from guppylang.tys.builtin import ( array_type_def, @@ -69,28 +75,30 @@ class array(Generic[_T, _n]): @guppy.extend_type(builtins, bool_type_def) class Bool: - @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) + @guppy.hugr_op(builtins, logic_op("And")) def __and__(self: bool, other: bool) -> bool: ... @guppy.custom(builtins, NoopCompiler()) def __bool__(self: bool) -> bool: ... - @guppy.hugr_op(builtins, logic_op("Eq", [tys.TypeArg(tys.BoundedNatArg(n=2))])) + @guppy.hugr_op(builtins, logic_op("Eq")) def __eq__(self: bool, other: bool) -> bool: ... - @guppy.hugr_op(builtins, int_op("ifrombool")) + @guppy.hugr_op(builtins, int_op("ifrombool", n_vars=0)) def __int__(self: bool) -> int: ... - @guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH + @guppy.hugr_op( + builtins, custom_op("ifrombool", args=[int_arg()]) + ) # TODO: Widen to INT_WIDTH def __nat__(self: bool) -> nat: ... @guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False) def __new__(x): ... - @guppy.hugr_op(builtins, logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))])) + @guppy.hugr_op(builtins, logic_op("Or")) def __or__(self: bool, other: bool) -> bool: ... - @guppy.hugr_op(builtins, DummyOp("Xor")) + @guppy.hugr_op(builtins, logic_op("Xor")) def __xor__(self: bool, other: bool) -> bool: ... @@ -105,13 +113,15 @@ def __add__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("iand")) def __and__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(builtins, DummyOp("itobool")) # TODO: Only works with width 1 ints + @guppy.hugr_op( + builtins, custom_op("itobool", args=[]) + ) # TODO: Only works with width 1 ints def __bool__(self: nat) -> bool: ... @guppy.custom(builtins, NoopCompiler()) def __ceil__(self: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("idivmod_u", num_params=2)) + @guppy.hugr_op(builtins, int_op("idivmod_u", n_vars=2)) def __divmod__(self: nat, other: nat) -> tuple[nat, nat]: ... @guppy.hugr_op(builtins, int_op("ieq")) @@ -123,7 +133,7 @@ def __float__(self: nat) -> float: ... @guppy.custom(builtins, NoopCompiler()) def __floor__(self: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("idiv_u", num_params=2)) + @guppy.hugr_op(builtins, int_op("idiv_u", n_vars=2)) def __floordiv__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("ige_u")) @@ -132,7 +142,7 @@ def __ge__(self: nat, other: nat) -> bool: ... @guppy.hugr_op(builtins, int_op("igt_u")) def __gt__(self: nat, other: nat) -> bool: ... - @guppy.hugr_op(builtins, DummyOp("iu_to_s")) # TODO + @guppy.hugr_op(builtins, custom_op("iu_to_s", args=[int_arg()])) # TODO def __int__(self: nat) -> int: ... @guppy.hugr_op(builtins, int_op("inot")) @@ -141,13 +151,13 @@ def __invert__(self: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("ile_u")) def __le__(self: nat, other: nat) -> bool: ... - @guppy.hugr_op(builtins, int_op("ishl", num_params=2)) + @guppy.hugr_op(builtins, int_op("ishl", n_vars=2)) def __lshift__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("ilt_u")) def __lt__(self: nat, other: nat) -> bool: ... - @guppy.hugr_op(builtins, int_op("imod_u", num_params=2)) + @guppy.hugr_op(builtins, int_op("imod_u", n_vars=2)) def __mod__(self: nat, other: nat) -> int: ... @guppy.hugr_op(builtins, int_op("imul")) @@ -168,7 +178,7 @@ def __or__(self: nat, other: nat) -> nat: ... @guppy.custom(builtins, NoopCompiler()) def __pos__(self: nat) -> nat: ... - @guppy.hugr_op(builtins, DummyOp("ipow")) # TODO + @guppy.hugr_op(builtins, custom_op("ipow", args=[int_arg()])) # TODO def __pow__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker()) @@ -177,16 +187,20 @@ def __radd__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker()) def __rand__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("idivmod_u", num_params=2), ReversingChecker()) + @guppy.hugr_op( + builtins, + int_op("idivmod_u", n_vars=2), + ReversingChecker(), + ) def __rdivmod__(self: nat, other: nat) -> tuple[nat, nat]: ... - @guppy.hugr_op(builtins, int_op("idiv_u", num_params=2), ReversingChecker()) + @guppy.hugr_op(builtins, int_op("idiv_u", n_vars=2), ReversingChecker()) def __rfloordiv__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("ishl", num_params=2), ReversingChecker()) + @guppy.hugr_op(builtins, int_op("ishl", n_vars=2), ReversingChecker()) def __rlshift__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("imod_u", num_params=2), ReversingChecker()) + @guppy.hugr_op(builtins, int_op("imod_u", n_vars=2), ReversingChecker()) def __rmod__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker()) @@ -198,13 +212,17 @@ def __ror__(self: nat, other: nat) -> nat: ... @guppy.custom(builtins, NoopCompiler()) def __round__(self: nat) -> nat: ... - @guppy.hugr_op(builtins, DummyOp("ipow"), ReversingChecker()) # TODO + @guppy.hugr_op( + builtins, + custom_op("ipow", args=[int_arg()]), + ReversingChecker(), + ) def __rpow__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("ishr", num_params=2), ReversingChecker()) + @guppy.hugr_op(builtins, int_op("ishr", n_vars=2), ReversingChecker()) def __rrshift__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(builtins, int_op("ishr", num_params=2)) + @guppy.hugr_op(builtins, int_op("ishr", n_vars=2)) def __rshift__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker()) @@ -240,13 +258,13 @@ def __add__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("iand")) def __and__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("itobool")) + @guppy.hugr_op(builtins, int_op("itobool", n_vars=0)) def __bool__(self: int) -> bool: ... @guppy.custom(builtins, NoopCompiler()) def __ceil__(self: int) -> int: ... - @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2)) + @guppy.hugr_op(builtins, int_op("idivmod_s", n_vars=2)) def __divmod__(self: int, other: int) -> tuple[int, int]: ... @guppy.hugr_op(builtins, int_op("ieq")) @@ -258,7 +276,7 @@ def __float__(self: int) -> float: ... @guppy.custom(builtins, NoopCompiler()) def __floor__(self: int) -> int: ... - @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2)) + @guppy.hugr_op(builtins, int_op("idiv_s", n_vars=2)) def __floordiv__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("ige_s")) @@ -276,19 +294,19 @@ def __invert__(self: int) -> int: ... @guppy.hugr_op(builtins, int_op("ile_s")) def __le__(self: int, other: int) -> bool: ... - @guppy.hugr_op(builtins, int_op("ishl", num_params=2)) # TODO: RHS is unsigned + @guppy.hugr_op(builtins, int_op("ishl", n_vars=2)) # TODO: RHS is unsigned def __lshift__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("ilt_s")) def __lt__(self: int, other: int) -> bool: ... - @guppy.hugr_op(builtins, int_op("imod_s", num_params=2)) + @guppy.hugr_op(builtins, int_op("imod_s", n_vars=2)) def __mod__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("imul")) def __mul__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, DummyOp("is_to_u")) # TODO + @guppy.hugr_op(builtins, custom_op("is_to_u", args=[int_arg()])) # TODO def __nat__(self: int) -> nat: ... @guppy.hugr_op(builtins, int_op("ine")) @@ -306,7 +324,7 @@ def __or__(self: int, other: int) -> int: ... @guppy.custom(builtins, NoopCompiler()) def __pos__(self: int) -> int: ... - @guppy.hugr_op(builtins, DummyOp("ipow")) # TODO + @guppy.hugr_op(builtins, custom_op("ipow", args=[int_arg()])) # TODO def __pow__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker()) @@ -315,18 +333,22 @@ def __radd__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker()) def __rand__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker()) + @guppy.hugr_op( + builtins, + int_op("idivmod_s", n_vars=2), + ReversingChecker(), + ) def __rdivmod__(self: int, other: int) -> tuple[int, int]: ... - @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker()) + @guppy.hugr_op(builtins, int_op("idiv_s", n_vars=2), ReversingChecker()) def __rfloordiv__(self: int, other: int) -> int: ... @guppy.hugr_op( - builtins, int_op("ishl", num_params=2), ReversingChecker() + builtins, int_op("ishl", n_vars=2), ReversingChecker() ) # TODO: RHS is unsigned def __rlshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker()) + @guppy.hugr_op(builtins, int_op("imod_s", n_vars=2), ReversingChecker()) def __rmod__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker()) @@ -338,15 +360,19 @@ def __ror__(self: int, other: int) -> int: ... @guppy.custom(builtins, NoopCompiler()) def __round__(self: int) -> int: ... - @guppy.hugr_op(builtins, DummyOp("ipow"), ReversingChecker()) # TODO + @guppy.hugr_op( + builtins, + custom_op("ipow", args=[int_arg()]), + ReversingChecker(), + ) def __rpow__(self: int, other: int) -> int: ... @guppy.hugr_op( - builtins, int_op("ishr", num_params=2), ReversingChecker() + builtins, int_op("ishr", n_vars=2), ReversingChecker() ) # TODO: RHS is unsigned def __rrshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("ishr", num_params=2)) # TODO: RHS is unsigned + @guppy.hugr_op(builtins, int_op("ishr", n_vars=2)) # TODO: RHS is unsigned def __rshift__(self: int, other: int) -> int: ... @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker()) @@ -440,7 +466,7 @@ def __new__(x): ... @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) def __pos__(self: float) -> float: ... - @guppy.hugr_op(builtins, DummyOp("fpow")) # TODO + @guppy.hugr_op(builtins, float_op("fpow", ext="guppylang.unsupported")) # TODO def __pow__(self: float, other: float) -> float: ... @guppy.hugr_op(builtins, float_op("fadd"), ReversingChecker(CoercingChecker())) @@ -460,11 +486,13 @@ def __rmod__(self: float, other: float) -> float: ... @guppy.hugr_op(builtins, float_op("fmul"), ReversingChecker(CoercingChecker())) def __rmul__(self: float, other: float) -> float: ... - @guppy.hugr_op(builtins, DummyOp("fround")) # TODO + @guppy.hugr_op(builtins, float_op("fround", ext="guppylang.unsupported")) # TODO def __round__(self: float) -> float: ... @guppy.hugr_op( - builtins, DummyOp("fpow"), ReversingChecker(DefaultCallChecker()) + builtins, + float_op("fpow", ext="guppylang.unsupported"), + ReversingChecker(DefaultCallChecker()), ) # TODO def __rpow__(self: float, other: float) -> float: ... @@ -488,34 +516,34 @@ def __trunc__(self: float) -> float: ... @guppy.extend_type(builtins, list_type_def) class List: - @guppy.hugr_op(builtins, DummyOp("Concat")) + @guppy.hugr_op(builtins, list_op("Concat")) def __add__(self: list[T], other: list[T]) -> list[T]: ... - @guppy.hugr_op(builtins, DummyOp("IsEmpty")) + @guppy.hugr_op(builtins, list_op("IsEmpty")) def __bool__(self: list[T]) -> bool: ... - @guppy.hugr_op(builtins, DummyOp("Contains")) + @guppy.hugr_op(builtins, list_op("Contains")) def __contains__(self: list[T], el: T) -> bool: ... - @guppy.hugr_op(builtins, DummyOp("AssertEmpty")) + @guppy.hugr_op(builtins, list_op("AssertEmpty")) def __end__(self: list[T]) -> None: ... - @guppy.hugr_op(builtins, DummyOp("Lookup")) + @guppy.hugr_op(builtins, list_op("Lookup")) def __getitem__(self: list[T], idx: int) -> T: ... - @guppy.hugr_op(builtins, DummyOp("IsNonEmpty")) + @guppy.hugr_op(builtins, list_op("IsNonEmpty")) def __hasnext__(self: list[T]) -> tuple[bool, list[T]]: ... @guppy.custom(builtins, NoopCompiler()) def __iter__(self: list[T]) -> list[T]: ... - @guppy.hugr_op(builtins, DummyOp("Length")) + @guppy.hugr_op(builtins, list_op("Length")) def __len__(self: list[T]) -> int: ... - @guppy.hugr_op(builtins, DummyOp("Repeat")) + @guppy.hugr_op(builtins, list_op("Repeat")) def __mul__(self: list[T], other: int) -> list[T]: ... - @guppy.hugr_op(builtins, DummyOp("Pop")) + @guppy.hugr_op(builtins, list_op("Pop")) def __next__(self: list[T]) -> tuple[T, list[T]]: ... @guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) @@ -524,10 +552,10 @@ def __new__(x): ... @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) def __setitem__(self: list[T], idx: int, value: T) -> None: ... - @guppy.hugr_op(builtins, DummyOp("Append"), ReversingChecker()) + @guppy.hugr_op(builtins, list_op("Append"), ReversingChecker()) def __radd__(self: list[T], other: list[T]) -> list[T]: ... - @guppy.hugr_op(builtins, DummyOp("Repeat"), ReversingChecker()) + @guppy.hugr_op(builtins, list_op("Repeat")) def __rmul__(self: list[T], other: int) -> list[T]: ... @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) @@ -539,13 +567,13 @@ def clear(self: list[T]) -> None: ... @guppy.custom(builtins, NoopCompiler()) # Can be noop since lists are immutable def copy(self: list[T]) -> list[T]: ... - @guppy.hugr_op(builtins, DummyOp("Count")) + @guppy.hugr_op(builtins, list_op("Count")) def count(self: list[T], elt: T) -> int: ... @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) def extend(self: list[T], seq: None) -> None: ... - @guppy.hugr_op(builtins, DummyOp("Find")) + @guppy.hugr_op(builtins, list_op("Find")) def index(self: list[T], elt: T) -> int: ... @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) @@ -566,44 +594,44 @@ def sort(self: list[T]) -> None: ... @guppy.extend_type(builtins, linst_type_def) class Linst: - @guppy.hugr_op(builtins, DummyOp("Append")) + @guppy.hugr_op(builtins, linst_op("Append")) def __add__(self: linst[L], other: linst[L]) -> linst[L]: ... - @guppy.hugr_op(builtins, DummyOp("AssertEmpty")) + @guppy.hugr_op(builtins, linst_op("AssertEmpty")) def __end__(self: linst[L]) -> None: ... - @guppy.hugr_op(builtins, DummyOp("IsNonempty")) + @guppy.hugr_op(builtins, linst_op("IsNonempty")) def __hasnext__(self: linst[L]) -> tuple[bool, linst[L]]: ... @guppy.custom(builtins, NoopCompiler()) def __iter__(self: linst[L]) -> linst[L]: ... - @guppy.hugr_op(builtins, DummyOp("Length")) + @guppy.hugr_op(builtins, linst_op("Length")) def __len__(self: linst[L]) -> tuple[int, linst[L]]: ... - @guppy.hugr_op(builtins, DummyOp("Pop")) + @guppy.hugr_op(builtins, linst_op("Pop")) def __next__(self: linst[L]) -> tuple[L, linst[L]]: ... @guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) def __new__(x): ... - @guppy.hugr_op(builtins, DummyOp("Append"), ReversingChecker()) + @guppy.hugr_op(builtins, linst_op("Append"), ReversingChecker()) def __radd__(self: linst[L], other: linst[L]) -> linst[L]: ... - @guppy.hugr_op(builtins, DummyOp("Repeat"), ReversingChecker()) + @guppy.hugr_op(builtins, linst_op("Repeat")) def __rmul__(self: linst[L], other: int) -> linst[L]: ... - @guppy.hugr_op(builtins, DummyOp("Push")) + @guppy.hugr_op(builtins, linst_op("Push")) def append(self: linst[L], elt: L) -> linst[L]: ... - @guppy.hugr_op(builtins, DummyOp("PopAt")) + @guppy.hugr_op(builtins, linst_op("PopAt")) def pop(self: linst[L], idx: int) -> tuple[L, linst[L]]: ... - @guppy.hugr_op(builtins, DummyOp("Reverse")) + @guppy.hugr_op(builtins, linst_op("Reverse")) def reverse(self: linst[L]) -> linst[L]: ... @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) - def sort(self: list[T]) -> None: ... + def sort(self: linst[T]) -> None: ... n = guppy.nat_var(builtins, "n") @@ -611,7 +639,12 @@ def sort(self: list[T]) -> None: ... @guppy.extend_type(builtins, array_type_def) class Array: - @guppy.hugr_op(builtins, DummyOp("ArrayGet")) + @guppy.hugr_op( + builtins, + custom_op( + "ArrayGet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0} + ), + ) def __getitem__(self: array[T, n], idx: int) -> T: ... @guppy.custom(builtins, checker=ArrayLenChecker()) diff --git a/guppylang/prelude/quantum.py b/guppylang/prelude/quantum.py index 95ecd11f..c409caea 100644 --- a/guppylang/prelude/quantum.py +++ b/guppylang/prelude/quantum.py @@ -2,29 +2,41 @@ # mypy: disable-error-code="empty-body, misc" -from hugr.serialization import ops, tys -from hugr.serialization.tys import TypeBound +from collections.abc import Callable + +from hugr import ops +from hugr import tys as ht from guppylang.decorator import guppy -from guppylang.hugr_builder.hugr import UNDEFINED from guppylang.module import GuppyModule -from guppylang.prelude._internal import MeasureCompiler, QAllocCompiler +from guppylang.prelude._internal.compiler import MeasureCompiler, QAllocCompiler +from guppylang.tys.subst import Inst quantum = GuppyModule("quantum") -def quantum_op(op_name: str) -> ops.OpType: - """Utility method to create Hugr quantum ops.""" - return ops.OpType( - ops.CustomOp(extension="quantum.tket2", name=op_name, args=[], parent=UNDEFINED) - ) +def quantum_op( + op_name: str, +) -> Callable[[ht.FunctionType, Inst], ops.DataflowOp]: + """Utility method to create Hugr quantum ops. + + Args: + op_name: The name of the quantum operation. + + Returns: + A function that takes an instantiation of the type arguments and returns + a concrete HUGR op. + """ + + def op(ty: ht.FunctionType, inst: Inst) -> ops.DataflowOp: + return ops.Custom( + name=op_name, extension="quantum.tket2", signature=ty, args=[] + ) + + return op -@guppy.type( - quantum, - tys.Type(tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any)), - linear=True, -) +@guppy.type(quantum, ht.Qubit, linear=True) class qubit: @guppy.custom(quantum, QAllocCompiler()) def __new__() -> "qubit": ... diff --git a/guppylang/tys/arg.py b/guppylang/tys/arg.py index 9189196c..d349f452 100644 --- a/guppylang/tys/arg.py +++ b/guppylang/tys/arg.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, TypeAlias -from hugr.serialization import tys +from hugr import tys as ht from guppylang.error import InternalGuppyError from guppylang.tys.common import ToHugr, Transformable, Transformer, Visitor @@ -23,7 +23,7 @@ @dataclass(frozen=True) -class ArgumentBase(ToHugr[tys.TypeArg], Transformable["Argument"], ABC): +class ArgumentBase(ToHugr[ht.TypeArg], Transformable["Argument"], ABC): """Abstract base class for arguments of parametrized types. For example, in the type `array[int, 42]` we have two arguments `int` and `42`. @@ -47,9 +47,10 @@ def unsolved_vars(self) -> set[ExistentialVar]: """The existential type variables contained in this argument.""" return self.ty.unsolved_vars - def to_hugr(self) -> tys.TypeArg: + def to_hugr(self) -> ht.TypeTypeArg: """Computes the Hugr representation of the argument.""" - return tys.TypeArg(tys.TypeTypeArg(ty=self.ty.to_hugr())) + ty: ht.Type = self.ty.to_hugr() + return ty.type_arg() def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this argument.""" @@ -72,19 +73,17 @@ def unsolved_vars(self) -> set[ExistentialVar]: """The existential const variables contained in this argument.""" return self.const.unsolved_vars - def to_hugr(self) -> tys.TypeArg: + def to_hugr(self) -> ht.TypeArg: """Computes the Hugr representation of this argument.""" from guppylang.tys.ty import NumericType match self.const: case ConstValue(value=v, ty=NumericType(kind=NumericType.Kind.Nat)): assert isinstance(v, int) - return tys.TypeArg(tys.BoundedNatArg(n=v)) + return ht.BoundedNatArg(n=v) case BoundConstVar(idx=idx): - hugr_ty = self.const.ty.to_hugr() - assert isinstance(hugr_ty.root, tys.Opaque) - param = tys.TypeParam(tys.BoundedNatParam(bound=None)) - return tys.TypeArg(tys.VariableArg(idx=idx, cached_decl=param)) + param = ht.BoundedNatParam(upper_bound=None) + return ht.VariableArg(idx=idx, param=param) case ConstValue() | BoundConstVar(): # TODO: Handle other cases besides nats raise NotImplementedError diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 82a838e4..0d5123a1 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal, TypeGuard -from hugr.serialization import tys +from hugr import tys as ht from guppylang.ast_util import AstNode from guppylang.definition.common import DefId @@ -123,31 +123,29 @@ def check_instantiate( return super().check_instantiate(args, globals, loc) -def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: +def _list_to_hugr(args: Sequence[Argument]) -> ht.Type: # Type checker ensures that we get a single arg of kind type [arg] = args assert isinstance(arg, TypeArg) - ty = tys.Opaque( + return ht.Opaque( extension="Collections", id="List", args=[arg.to_hugr()], bound=arg.ty.hugr_bound, ) - return tys.Type(ty) -def _array_to_hugr(args: Sequence[Argument]) -> tys.Type: +def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: # Type checker ensures that we get a two args [ty_arg, len_arg] = args assert isinstance(ty_arg, TypeArg) assert isinstance(len_arg, ConstArg) - ty = tys.Opaque( + return ht.Opaque( extension="prelude", id="array", args=[len_arg.to_hugr(), ty_arg.to_hugr()], bound=ty_arg.ty.hugr_bound, ) - return tys.Type(ty) callable_type_def = _CallableTypeDef(DefId.fresh(), None) @@ -159,7 +157,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> tys.Type: defined_at=None, params=[], always_linear=False, - to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))), + to_hugr=lambda _: ht.Bool, ) nat_type_def = _NumericTypeDef( DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat) diff --git a/guppylang/tys/param.py b/guppylang/tys/param.py index 52b62649..05aa75e8 100644 --- a/guppylang/tys/param.py +++ b/guppylang/tys/param.py @@ -3,8 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, TypeAlias -from hugr.serialization import tys -from hugr.serialization.tys import TypeBound +from hugr import tys as ht from typing_extensions import Self from guppylang.ast_util import AstNode @@ -28,7 +27,7 @@ @dataclass(frozen=True) -class ParameterBase(ToHugr[tys.TypeParam], ABC): +class ParameterBase(ToHugr[ht.TypeParam], ABC): """Abstract base class for parameters used in function and type definitions. For example, when defining a struct type @@ -121,12 +120,10 @@ def to_bound(self, idx: int | None = None) -> Argument: idx = self.idx return TypeArg(BoundTypeVar(self.name, idx, self.can_be_linear)) - def to_hugr(self) -> tys.TypeParam: + def to_hugr(self) -> ht.TypeParam: """Computes the Hugr representation of the parameter.""" - return tys.TypeParam( - tys.TypeTypeParam( - b=tys.TypeBound.Any if self.can_be_linear else TypeBound.Copyable - ) + return ht.TypeTypeParam( + bound=ht.TypeBound.Any if self.can_be_linear else ht.TypeBound.Copyable ) @@ -180,17 +177,16 @@ def to_bound(self, idx: int | None = None) -> Argument: idx = self.idx return ConstArg(BoundConstVar(self.ty, self.name, idx)) - def to_hugr(self) -> tys.TypeParam: + def to_hugr(self) -> ht.TypeParam: """Computes the Hugr representation of the parameter.""" from guppylang.tys.ty import NumericType match self.ty: case NumericType(kind=NumericType.Kind.Nat): - return tys.TypeParam(tys.BoundedNatParam(bound=None)) + return ht.BoundedNatParam(upper_bound=None) case _: - hugr_ty = self.ty.to_hugr() - assert isinstance(hugr_ty.root, tys.Opaque) - return tys.TypeParam(tys.StringParam()) + assert isinstance(self.ty.to_hugr(), ht.Opaque) + return ht.StringParam() def check_all_args( diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 73bc52c7..348b005e 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -5,8 +5,9 @@ from functools import cached_property from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast -from hugr.serialization import tys -from hugr.serialization.tys import TypeBound +import hugr.std.float +import hugr.std.int +from hugr import tys as ht from guppylang.error import InternalGuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg @@ -22,7 +23,7 @@ @dataclass(frozen=True) -class TypeBase(ToHugr[tys.Type], Transformable["Type"], ABC): +class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC): """Abstract base class for all Guppy types. Note that all subclasses are expected to be immutable. @@ -35,7 +36,7 @@ def linear(self) -> bool: @cached_property @abstractmethod - def hugr_bound(self) -> tys.TypeBound: + def hugr_bound(self) -> ht.TypeBound: """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`. This needs to be specified explicitly, since opaque nonlinear types in a Hugr @@ -118,11 +119,11 @@ def unsolved_vars(self) -> set[ExistentialVar]: return set().union(*(arg.unsolved_vars for arg in self.args)) @cached_property - def hugr_bound(self) -> tys.TypeBound: + def hugr_bound(self) -> ht.TypeBound: """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" if self.linear: - return tys.TypeBound.Any - return tys.TypeBound.join( + return ht.TypeBound.Any + return ht.TypeBound.join( *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg)) ) @@ -146,21 +147,21 @@ class BoundTypeVar(TypeBase, BoundVar): linear: bool @cached_property - def hugr_bound(self) -> tys.TypeBound: + def hugr_bound(self) -> ht.TypeBound: """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" if self.linear: - return TypeBound.Any + return ht.TypeBound.Any # We're conservative and don't require equatability for non-linear variables. # This is fine since Guppy doesn't use the equatable feature anyways. - return TypeBound.Copyable + return ht.TypeBound.Copyable def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Variable: """Computes the Hugr representation of the type.""" - return tys.Type(tys.Variable(i=self.idx, b=self.hugr_bound)) + return ht.Variable(idx=self.idx, bound=self.hugr_bound) def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this type.""" @@ -198,7 +199,7 @@ def unsolved_vars(self) -> set[ExistentialVar]: return {self} @cached_property - def hugr_bound(self) -> tys.TypeBound: + def hugr_bound(self) -> ht.TypeBound: """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" raise InternalGuppyError( "Tried to compute bound of unsolved existential type variable" @@ -208,7 +209,7 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Type: """Computes the Hugr representation of the type.""" raise InternalGuppyError( "Tried to convert unsolved existential type variable to Hugr" @@ -228,7 +229,7 @@ class NoneType(TypeBase): """Type of tuples.""" linear: bool = field(default=False, init=False) - hugr_bound: tys.TypeBound = field(default=tys.TypeBound.Copyable, init=False) + hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False) # Flag to avoid turning the type into a row when calling `type_to_row()`. This is # used to make sure that type vars instantiated to Nones are not broken up into @@ -239,9 +240,9 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Tuple: """Computes the Hugr representation of the type.""" - return TupleType([]).to_hugr() + return ht.Tuple() def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this type.""" @@ -276,32 +277,18 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.ExtType: """Computes the Hugr representation of the type.""" match self.kind: case NumericType.Kind.Nat | NumericType.Kind.Int: - return tys.Type( - tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))], - bound=tys.TypeBound.Copyable, - ) - ) + return hugr.std.int.int_t(NumericType.INT_WIDTH) case NumericType.Kind.Float: - return tys.Type( - tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=tys.TypeBound.Copyable, - ) - ) + return hugr.std.float.FLOAT_T @property - def hugr_bound(self) -> tys.TypeBound: + def hugr_bound(self) -> ht.TypeBound: """The Hugr bound of this type, i.e. `Any` or `Copyable`""" - return tys.TypeBound.Copyable + return ht.TypeBound.Copyable def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this type.""" @@ -324,7 +311,7 @@ class FunctionType(ParametrizedTypeBase): args: Sequence[Argument] = field(init=False) linear: bool = field(default=False, init=False) intrinsically_linear: bool = field(default=False, init=False) - hugr_bound: tys.TypeBound = field(default=TypeBound.Copyable, init=False) + hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False) def __init__( self, @@ -351,21 +338,21 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.FunctionType: """Computes the Hugr representation of the type.""" if self.parametrized: raise InternalGuppyError( "Tried to convert parametrised function type to Hugr. Use " "`to_hugr_poly` instead" ) - return tys.Type(self._to_hugr_function_type()) + return self._to_hugr_function_type() - def to_hugr_poly(self) -> tys.PolyFuncType: + def to_hugr_poly(self) -> ht.PolyFuncType: """Computes the Hugr `PolyFuncType` representation of the type.""" func_ty = self._to_hugr_function_type() - return tys.PolyFuncType(params=[p.to_hugr() for p in self.params], body=func_ty) + return ht.PolyFuncType(params=[p.to_hugr() for p in self.params], body=func_ty) - def _to_hugr_function_type(self) -> tys.FunctionType: + def _to_hugr_function_type(self) -> ht.FunctionType: """Helper method to compute the Hugr `FunctionType` representation of the type. The resulting `FunctionType` can then be embedded into a Hugr `Type` or a Hugr @@ -373,7 +360,7 @@ def _to_hugr_function_type(self) -> tys.FunctionType: """ ins = [t.to_hugr() for t in self.inputs] outs = [t.to_hugr() for t in type_to_row(self.output)] - return tys.FunctionType(input=ins, output=outs) + return ht.FunctionType(input=ins, output=outs) def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this type.""" @@ -448,14 +435,9 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Tuple: """Computes the Hugr representation of the type.""" - # Tuples are encoded as a unary sum. Note that we need to make a copy of this - # tuple with `preserve=False` to ensure that it can be broken up into a row (if - # this tuple was created by instantiating a type variable, it is still - # represented as a *row* sum). - tuple_ty = TupleType(self.element_types, preserve=False) - return SumType([tuple_ty]).to_hugr() + return ht.Tuple(*row_to_hugr(self.element_types)) def transform(self, transformer: Transformer) -> "Type": """Accepts a transformer on this type.""" @@ -488,15 +470,15 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Sum: """Computes the Hugr representation of the type.""" rows = [type_to_row(ty) for ty in self.element_types] - sum_inner: tys.UnitSum | tys.GeneralSum if all(len(row) == 0 for row in rows): - sum_inner = tys.UnitSum(size=len(rows)) + return ht.UnitSum(size=len(rows)) + elif len(rows) == 1: + return ht.Tuple(*row_to_hugr(rows[0])) else: - sum_inner = tys.GeneralSum(rows=rows_to_hugr(rows)) - return tys.Type(tys.SumType(sum_inner)) + return ht.Sum(variant_rows=rows_to_hugr(rows)) def transform(self, transformer: Transformer) -> "Type": """Accepts a transformer on this type.""" @@ -521,7 +503,7 @@ def intrinsically_linear(self) -> bool: return self.defn.always_linear @property - def hugr_bound(self) -> tys.TypeBound: + def hugr_bound(self) -> ht.TypeBound: """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" if self.defn.bound is not None: return self.defn.bound @@ -531,7 +513,7 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Type: """Computes the Hugr representation of the type.""" return self.defn.to_hugr(self.args) @@ -571,9 +553,9 @@ def cast(self) -> "Type": """Casts an implementor of `TypeBase` into a `Type`.""" return self - def to_hugr(self) -> tys.Type: + def to_hugr(self) -> ht.Tuple: """Computes the Hugr representation of the type.""" - return TupleType([f.ty for f in self.fields]).to_hugr() + return ht.Tuple(*(f.ty.to_hugr() for f in self.fields)) def transform(self, transformer: Transformer) -> "Type": """Accepts a transformer on this type.""" @@ -623,12 +605,12 @@ def type_to_row(ty: Type) -> TypeRow: return [ty] -def row_to_hugr(row: TypeRow) -> tys.TypeRow: +def row_to_hugr(row: TypeRow) -> ht.TypeRow: """Computes the Hugr representation of a type row.""" return [ty.to_hugr() for ty in row] -def rows_to_hugr(rows: Sequence[TypeRow]) -> list[tys.TypeRow]: +def rows_to_hugr(rows: Sequence[TypeRow]) -> list[ht.TypeRow]: """Computes the Hugr representation of a sequence of rows.""" return [row_to_hugr(row) for row in rows] diff --git a/poetry.lock b/poetry.lock index d1fc4d0e..64c3355d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -689,56 +689,63 @@ files = [ [[package]] name = "numpy" -version = "2.0.1" +version = "2.1.0" description = "Fundamental package for array computing in Python" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "numpy-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0fbb536eac80e27a2793ffd787895242b7f18ef792563d742c2d673bfcb75134"}, - {file = "numpy-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:69ff563d43c69b1baba77af455dd0a839df8d25e8590e79c90fcbe1499ebde42"}, - {file = "numpy-2.0.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1b902ce0e0a5bb7704556a217c4f63a7974f8f43e090aff03fcf262e0b135e02"}, - {file = "numpy-2.0.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:f1659887361a7151f89e79b276ed8dff3d75877df906328f14d8bb40bb4f5101"}, - {file = "numpy-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4658c398d65d1b25e1760de3157011a80375da861709abd7cef3bad65d6543f9"}, - {file = "numpy-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4127d4303b9ac9f94ca0441138acead39928938660ca58329fe156f84b9f3015"}, - {file = "numpy-2.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e5eeca8067ad04bc8a2a8731183d51d7cbaac66d86085d5f4766ee6bf19c7f87"}, - {file = "numpy-2.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9adbd9bb520c866e1bfd7e10e1880a1f7749f1f6e5017686a5fbb9b72cf69f82"}, - {file = "numpy-2.0.1-cp310-cp310-win32.whl", hash = "sha256:7b9853803278db3bdcc6cd5beca37815b133e9e77ff3d4733c247414e78eb8d1"}, - {file = "numpy-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:81b0893a39bc5b865b8bf89e9ad7807e16717f19868e9d234bdaf9b1f1393868"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75b4e316c5902d8163ef9d423b1c3f2f6252226d1aa5cd8a0a03a7d01ffc6268"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6e4eeb6eb2fced786e32e6d8df9e755ce5be920d17f7ce00bc38fcde8ccdbf9e"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a1e01dcaab205fbece13c1410253a9eea1b1c9b61d237b6fa59bcc46e8e89343"}, - {file = "numpy-2.0.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a8fc2de81ad835d999113ddf87d1ea2b0f4704cbd947c948d2f5513deafe5a7b"}, - {file = "numpy-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a3d94942c331dd4e0e1147f7a8699a4aa47dffc11bf8a1523c12af8b2e91bbe"}, - {file = "numpy-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15eb4eca47d36ec3f78cde0a3a2ee24cf05ca7396ef808dda2c0ddad7c2bde67"}, - {file = "numpy-2.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b83e16a5511d1b1f8a88cbabb1a6f6a499f82c062a4251892d9ad5d609863fb7"}, - {file = "numpy-2.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f87fec1f9bc1efd23f4227becff04bd0e979e23ca50cc92ec88b38489db3b55"}, - {file = "numpy-2.0.1-cp311-cp311-win32.whl", hash = "sha256:36d3a9405fd7c511804dc56fc32974fa5533bdeb3cd1604d6b8ff1d292b819c4"}, - {file = "numpy-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:08458fbf403bff5e2b45f08eda195d4b0c9b35682311da5a5a0a0925b11b9bd8"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bf4e6f4a2a2e26655717a1983ef6324f2664d7011f6ef7482e8c0b3d51e82ac"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6fddc5fe258d3328cd8e3d7d3e02234c5d70e01ebe377a6ab92adb14039cb4"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5daab361be6ddeb299a918a7c0864fa8618af66019138263247af405018b04e1"}, - {file = "numpy-2.0.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:ea2326a4dca88e4a274ba3a4405eb6c6467d3ffbd8c7d38632502eaae3820587"}, - {file = "numpy-2.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:529af13c5f4b7a932fb0e1911d3a75da204eff023ee5e0e79c1751564221a5c8"}, - {file = "numpy-2.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6790654cb13eab303d8402354fabd47472b24635700f631f041bd0b65e37298a"}, - {file = "numpy-2.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cbab9fc9c391700e3e1287666dfd82d8666d10e69a6c4a09ab97574c0b7ee0a7"}, - {file = "numpy-2.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99d0d92a5e3613c33a5f01db206a33f8fdf3d71f2912b0de1739894668b7a93b"}, - {file = "numpy-2.0.1-cp312-cp312-win32.whl", hash = "sha256:173a00b9995f73b79eb0191129f2455f1e34c203f559dd118636858cc452a1bf"}, - {file = "numpy-2.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:bb2124fdc6e62baae159ebcfa368708867eb56806804d005860b6007388df171"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc085b28d62ff4009364e7ca34b80a9a080cbd97c2c0630bb5f7f770dae9414"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8fae4ebbf95a179c1156fab0b142b74e4ba4204c87bde8d3d8b6f9c34c5825ef"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:72dc22e9ec8f6eaa206deb1b1355eb2e253899d7347f5e2fae5f0af613741d06"}, - {file = "numpy-2.0.1-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:ec87f5f8aca726117a1c9b7083e7656a9d0d606eec7299cc067bb83d26f16e0c"}, - {file = "numpy-2.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f682ea61a88479d9498bf2091fdcd722b090724b08b31d63e022adc063bad59"}, - {file = "numpy-2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8efc84f01c1cd7e34b3fb310183e72fcdf55293ee736d679b6d35b35d80bba26"}, - {file = "numpy-2.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3fdabe3e2a52bc4eff8dc7a5044342f8bd9f11ef0934fcd3289a788c0eb10018"}, - {file = "numpy-2.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:24a0e1befbfa14615b49ba9659d3d8818a0f4d8a1c5822af8696706fbda7310c"}, - {file = "numpy-2.0.1-cp39-cp39-win32.whl", hash = "sha256:f9cf5ea551aec449206954b075db819f52adc1638d46a6738253a712d553c7b4"}, - {file = "numpy-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:e9e81fa9017eaa416c056e5d9e71be93d05e2c3c2ab308d23307a8bc4443c368"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:61728fba1e464f789b11deb78a57805c70b2ed02343560456190d0501ba37b0f"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-macosx_14_0_x86_64.whl", hash = "sha256:12f5d865d60fb9734e60a60f1d5afa6d962d8d4467c120a1c0cda6eb2964437d"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eacf3291e263d5a67d8c1a581a8ebbcfd6447204ef58828caf69a5e3e8c75990"}, - {file = "numpy-2.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2c3a346ae20cfd80b6cfd3e60dc179963ef2ea58da5ec074fd3d9e7a1e7ba97f"}, - {file = "numpy-2.0.1.tar.gz", hash = "sha256:485b87235796410c3519a699cfe1faab097e509e90ebb05dcd098db2ae87e7b3"}, + {file = "numpy-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8"}, + {file = "numpy-2.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911"}, + {file = "numpy-2.1.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673"}, + {file = "numpy-2.1.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b"}, + {file = "numpy-2.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b"}, + {file = "numpy-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f"}, + {file = "numpy-2.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84"}, + {file = "numpy-2.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33"}, + {file = "numpy-2.1.0-cp310-cp310-win32.whl", hash = "sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211"}, + {file = "numpy-2.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa"}, + {file = "numpy-2.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce"}, + {file = "numpy-2.1.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1"}, + {file = "numpy-2.1.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e"}, + {file = "numpy-2.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb"}, + {file = "numpy-2.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3"}, + {file = "numpy-2.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36"}, + {file = "numpy-2.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd"}, + {file = "numpy-2.1.0-cp311-cp311-win32.whl", hash = "sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e"}, + {file = "numpy-2.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b"}, + {file = "numpy-2.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736"}, + {file = "numpy-2.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529"}, + {file = "numpy-2.1.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1"}, + {file = "numpy-2.1.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8"}, + {file = "numpy-2.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745"}, + {file = "numpy-2.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111"}, + {file = "numpy-2.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0"}, + {file = "numpy-2.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574"}, + {file = "numpy-2.1.0-cp312-cp312-win32.whl", hash = "sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02"}, + {file = "numpy-2.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc"}, + {file = "numpy-2.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667"}, + {file = "numpy-2.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e"}, + {file = "numpy-2.1.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33"}, + {file = "numpy-2.1.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b"}, + {file = "numpy-2.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195"}, + {file = "numpy-2.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977"}, + {file = "numpy-2.1.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1"}, + {file = "numpy-2.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62"}, + {file = "numpy-2.1.0-cp313-cp313-win32.whl", hash = "sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324"}, + {file = "numpy-2.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5"}, + {file = "numpy-2.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06"}, + {file = "numpy-2.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883"}, + {file = "numpy-2.1.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300"}, + {file = "numpy-2.1.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d"}, + {file = "numpy-2.1.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd"}, + {file = "numpy-2.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6"}, + {file = "numpy-2.1.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a"}, + {file = "numpy-2.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb"}, + {file = "numpy-2.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b"}, + {file = "numpy-2.1.0-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d"}, + {file = "numpy-2.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575"}, + {file = "numpy-2.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2"}, + {file = "numpy-2.1.0.tar.gz", hash = "sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2"}, ] [[package]] @@ -1165,29 +1172,29 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.6.0" +version = "0.6.1" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.0-py3-none-linux_armv6l.whl", hash = "sha256:92dcce923e5df265781e5fc76f9a1edad52201a7aafe56e586b90988d5239013"}, - {file = "ruff-0.6.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:31b90ff9dc79ed476c04e957ba7e2b95c3fceb76148f2079d0d68a908d2cfae7"}, - {file = "ruff-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6d834a9ec9f8287dd6c3297058b3a265ed6b59233db22593379ee38ebc4b9768"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2089267692696aba342179471831a085043f218706e642564812145df8b8d0d"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aa62b423ee4bbd8765f2c1dbe8f6aac203e0583993a91453dc0a449d465c84da"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7344e1a964b16b1137ea361d6516ce4ee61a0403fa94252a1913ecc1311adcae"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:487f3a35c3f33bf82be212ce15dc6278ea854e35573a3f809442f73bec8b2760"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75db409984077a793cf344d499165298a6f65449e905747ac65983b12e3e64b1"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84908bd603533ecf1db456d8fc2665d1f4335d722e84bc871d3bbd2d1116c272"}, - {file = "ruff-0.6.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f1749a0aef3ec41ed91a0e2127a6ae97d2e2853af16dbd4f3c00d7a3af726c5"}, - {file = "ruff-0.6.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:016fea751e2bcfbbd2f8cb19b97b37b3fd33148e4df45b526e87096f4e17354f"}, - {file = "ruff-0.6.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6ae80f141b53b2e36e230017e64f5ea2def18fac14334ffceaae1b780d70c4f7"}, - {file = "ruff-0.6.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:eaaaf33ea4b3f63fd264d6a6f4a73fa224bbfda4b438ffea59a5340f4afa2bb5"}, - {file = "ruff-0.6.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7667ddd1fc688150a7ca4137140867584c63309695a30016880caf20831503a0"}, - {file = "ruff-0.6.0-py3-none-win32.whl", hash = "sha256:ae48365aae60d40865a412356f8c6f2c0be1c928591168111eaf07eaefa6bea3"}, - {file = "ruff-0.6.0-py3-none-win_amd64.whl", hash = "sha256:774032b507c96f0c803c8237ce7d2ef3934df208a09c40fa809c2931f957fe5e"}, - {file = "ruff-0.6.0-py3-none-win_arm64.whl", hash = "sha256:a5366e8c3ae6b2dc32821749b532606c42e609a99b0ae1472cf601da931a048c"}, - {file = "ruff-0.6.0.tar.gz", hash = "sha256:272a81830f68f9bd19d49eaf7fa01a5545c5a2e86f32a9935bb0e4bb9a1db5b8"}, + {file = "ruff-0.6.1-py3-none-linux_armv6l.whl", hash = "sha256:b4bb7de6a24169dc023f992718a9417380301b0c2da0fe85919f47264fb8add9"}, + {file = "ruff-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:45efaae53b360c81043e311cdec8a7696420b3d3e8935202c2846e7a97d4edae"}, + {file = "ruff-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bc60c7d71b732c8fa73cf995efc0c836a2fd8b9810e115be8babb24ae87e0850"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c7477c3b9da822e2db0b4e0b59e61b8a23e87886e727b327e7dcaf06213c5cf"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a0af7ab3f86e3dc9f157a928e08e26c4b40707d0612b01cd577cc84b8905cc9"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:392688dbb50fecf1bf7126731c90c11a9df1c3a4cdc3f481b53e851da5634fa5"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5278d3e095ccc8c30430bcc9bc550f778790acc211865520f3041910a28d0024"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe6d5f65d6f276ee7a0fc50a0cecaccb362d30ef98a110f99cac1c7872df2f18"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e0dd11e2ae553ee5c92a81731d88a9883af8db7408db47fc81887c1f8b672e"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d812615525a34ecfc07fd93f906ef5b93656be01dfae9a819e31caa6cfe758a1"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faaa4060f4064c3b7aaaa27328080c932fa142786f8142aff095b42b6a2eb631"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:99d7ae0df47c62729d58765c593ea54c2546d5de213f2af2a19442d50a10cec9"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9eb18dfd7b613eec000e3738b3f0e4398bf0153cb80bfa3e351b3c1c2f6d7b15"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c62bc04c6723a81e25e71715aa59489f15034d69bf641df88cb38bdc32fd1dbb"}, + {file = "ruff-0.6.1-py3-none-win32.whl", hash = "sha256:9fb4c4e8b83f19c9477a8745e56d2eeef07a7ff50b68a6998f7d9e2e3887bdc4"}, + {file = "ruff-0.6.1-py3-none-win_amd64.whl", hash = "sha256:c2ebfc8f51ef4aca05dad4552bbcf6fe8d1f75b2f6af546cc47cc1c1ca916b5b"}, + {file = "ruff-0.6.1-py3-none-win_arm64.whl", hash = "sha256:3bc81074971b0ffad1bd0c52284b22411f02a11a012082a76ac6da153536e014"}, + {file = "ruff-0.6.1.tar.gz", hash = "sha256:af3ffd8c6563acb8848d33cd19a69b9bfe943667f0419ca083f8ebe4224a3436"}, ] [[package]] diff --git a/tests/error/util.py b/tests/error/util.py index 1eb6964f..a2662615 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -1,8 +1,8 @@ import importlib.util import pathlib import pytest -from hugr.serialization import tys -from hugr.serialization.tys import TypeBound +from hugr import tys +from hugr.tys import TypeBound from guppylang.error import GuppyError from guppylang.module import GuppyModule @@ -29,7 +29,7 @@ def run_error_test(file, capsys): @decorator.guppy.type( - util, tys.Type(tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)) + util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable) ) class NonBool: pass diff --git a/tests/hugr/__init__.py b/tests/hugr/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py deleted file mode 100644 index a6a5cc65..00000000 --- a/tests/hugr/test_dummy_nodes.py +++ /dev/null @@ -1,42 +0,0 @@ -from hugr.serialization import ops - -from guppylang.tys.builtin import bool_type -from guppylang.tys.ty import FunctionType, TupleType -from guppylang.hugr_builder.hugr import Hugr, DummyOp - - -def test_single_dummy(): - g = Hugr() - defn = g.add_def(FunctionType([bool_type()], bool_type()), g.root, "test") - dfg = g.add_dfg(defn) - inp = g.add_input([bool_type()], dfg).out_port(0) - dummy = g.add_node( - DummyOp("dummy"), inputs=[inp], output_types=[bool_type()], parent=dfg - ) - g.add_output([dummy.out_port(0)], parent=dfg) - - g.remove_dummy_nodes() - [decl] = [n for n in g.nodes() if isinstance(n.op.root, ops.FuncDecl)] - assert decl.op.root.name == "dummy" - - -def test_unique_names(): - g = Hugr() - defn = g.add_def( - FunctionType([bool_type()], TupleType([bool_type(), bool_type()])), - g.root, - "test", - ) - dfg = g.add_dfg(defn) - inp = g.add_input([bool_type()], dfg).out_port(0) - dummy1 = g.add_node( - DummyOp("dummy"), inputs=[inp], output_types=[bool_type()], parent=dfg - ) - dummy2 = g.add_node( - DummyOp("dummy"), inputs=[inp], output_types=[bool_type()], parent=dfg - ) - g.add_output([dummy1.out_port(0), dummy2.out_port(0)], parent=dfg) - - g.remove_dummy_nodes() - [decl1, decl2] = [n for n in g.nodes() if isinstance(n.op.root, ops.FuncDecl)] - assert {decl1.op.root.name, decl2.op.root.name} == {"dummy", "dummy$1"} diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 76aaad05..b27bf353 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,4 +1,4 @@ -from guppylang.hugr_builder.hugr import Hugr +from hugr import Hugr from pathlib import Path import pytest @@ -22,9 +22,9 @@ def validate_json(hugr: str): except ImportError: pytest.skip("Skipping validation") - def validate_impl(hugr, name=None): + def validate_impl(hugr: Hugr, name=None): # Validate via the json encoding - js = hugr.serialize() + js = hugr.to_json() validate_json(js) if export_test_cases_dir: @@ -48,7 +48,7 @@ def f(hugr: Hugr, expected: int, fn_name: str = "main"): if not hasattr(execute_llvm, "run_int_function"): pytest.skip("Skipping llvm execution") - hugr_json: str = hugr.serialize() + hugr_json: str = hugr.to_json() res = execute_llvm.run_int_function(hugr_json, fn_name) if res != expected: raise LLVMException( diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 2c62b66b..49ac3fc7 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -1,8 +1,8 @@ -from hugr.serialization import ops +from hugr import ops +from hugr.std.int import IntVal from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.prelude._internal import ConstInt from guppylang.prelude.builtins import array from tests.util import compile_guppy @@ -17,14 +17,10 @@ def main(xs: array[float, 42]) -> int: hg = module.compile() validate(hg) - [val] = [ - node.op.root.v.root - for node in hg.nodes() - if isinstance(node.op.root, ops.Const) - ] - assert isinstance(val, ops.ExtensionValue) - assert isinstance(val.value.v, ConstInt) - assert val.value.v.value == 42 + [val] = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)] + assert isinstance(val, ops.Const) + assert isinstance(val.val, IntVal) + assert val.val.v == 42 def test_index(validate): diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index c3566b90..01e90b76 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -1,4 +1,4 @@ -from hugr.serialization import ops +from hugr import ops from guppylang.decorator import guppy from guppylang.module import GuppyModule @@ -23,7 +23,7 @@ def void() -> None: def test_copy(validate): @compile_guppy - def copy(x: int) -> (int, int): + def copy(x: int) -> tuple[int, int]: return x, x validate(copy) @@ -69,8 +69,9 @@ def func_name() -> None: return [def_op] = [ - n.op.root for n in func_name.nodes() if isinstance(n.op.root, ops.FuncDefn) + data.op for n, data in func_name.nodes() if isinstance(data.op, ops.FuncDefn) ] + assert isinstance(def_op, ops.FuncDefn) assert def_op.name == "func_name" @@ -80,11 +81,11 @@ def test_func_decl_name(): @guppy.declare(module) def func_name() -> None: ... + hugr = module.compile() [def_op] = [ - n.op.root - for n in module.compile().nodes() - if isinstance(n.op.root, ops.FuncDecl) + data.op for n, data in hugr.nodes() if isinstance(data.op, ops.FuncDecl) ] + assert isinstance(def_op, ops.FuncDecl) assert def_op.name == "func_name" diff --git a/tests/integration/test_comprehension.py b/tests/integration/test_comprehension.py index 956a628e..6dddcbeb 100644 --- a/tests/integration/test_comprehension.py +++ b/tests/integration/test_comprehension.py @@ -1,4 +1,4 @@ -from hugr.serialization import tys +from hugr import tys from guppylang.decorator import guppy from guppylang.module import GuppyModule @@ -270,11 +270,7 @@ def test_nonlinear_next_linear_iter(validate): @guppy.type( module, - tys.Type( - tys.Opaque( - extension="prelude", id="qubit", args=[], bound=tys.TypeBound.Any - ) - ), + tys.Opaque(extension="prelude", id="qubit", args=[], bound=tys.TypeBound.Any), linear=True, ) class MyIter: diff --git a/tests/integration/test_extern.py b/tests/integration/test_extern.py index cfe2ff3f..c9591d5a 100644 --- a/tests/integration/test_extern.py +++ b/tests/integration/test_extern.py @@ -1,4 +1,4 @@ -from hugr.serialization import ops +from hugr import ops, val from guppylang.decorator import guppy from guppylang.module import GuppyModule @@ -16,9 +16,9 @@ def main() -> float: hg = module.compile() validate(hg) - [c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)] - assert isinstance(c.v.root, ops.ExtensionValue) - assert c.v.root.value.v["symbol"] == "ext" + [c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] + assert isinstance(c.val, val.Extension) + assert c.val.val["symbol"] == "ext" def test_extern_alt_symbol(validate): @@ -33,9 +33,9 @@ def main() -> int: hg = module.compile() validate(hg) - [c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)] - assert isinstance(c.v.root, ops.ExtensionValue) - assert c.v.root.value.v["symbol"] == "foo" + [c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] + assert isinstance(c.val, val.Extension) + assert c.val.val["symbol"] == "foo" def test_extern_tuple(validate): diff --git a/tests/integration/test_nested.py b/tests/integration/test_nested.py index aac788ba..fea6e9cf 100644 --- a/tests/integration/test_nested.py +++ b/tests/integration/test_nested.py @@ -129,6 +129,20 @@ def bar() -> int: validate(foo) +def test_capture_fn(validate): + @compile_guppy + def foo() -> bool: + def f(x: bool) -> bool: + return x + + def g(b: bool) -> bool: + return f(b) + + return g(True) + + validate(foo) + + def test_capture_cfg(validate): @compile_guppy def foo(x: int) -> int: diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 5edb4d79..4bf961d1 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -2,9 +2,10 @@ import pytest +from hugr import Wire + from guppylang.decorator import guppy from guppylang.definition.custom import CustomCallCompiler -from guppylang.hugr_builder.hugr import OutPortV from guppylang.module import GuppyModule from guppylang.prelude.builtins import array from guppylang.prelude.quantum import qubit @@ -278,7 +279,7 @@ def main() -> None: def test_custom_higher_order(): class CustomCompiler(CustomCallCompiler): - def compile(self, args: list[OutPortV]) -> list[OutPortV]: + def compile(self, args: list[Wire]) -> list[Wire]: return args module = GuppyModule("test") diff --git a/tests/integration/test_struct.py b/tests/integration/test_struct.py index b465c826..a048a274 100644 --- a/tests/integration/test_struct.py +++ b/tests/integration/test_struct.py @@ -143,3 +143,21 @@ def foo() -> MyStruct: return s validate(module.compile()) + + +def test_field_access_and_drop(validate): + module = GuppyModule("module") + + @guppy.struct(module) + class MyStruct: + x: int + y: float + z: tuple[int, int] + + @guppy(module) + def foo() -> float: + # Access a field of an unnamed struct, + # dropping all the other fields + return MyStruct(42, 3.14, (1, 2)).y + + validate(module.compile()) diff --git a/tests/util.py b/tests/util.py index 0ad04cb4..a6b428c4 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING, Any +from hugr import Hugr + import guppylang from guppylang.definition.function import RawFunctionDef -from guppylang.hugr_builder.hugr import Hugr from guppylang.module import GuppyModule if TYPE_CHECKING: @@ -33,7 +34,7 @@ def dump_llvm(hugr: Hugr): try: from execute_llvm import compile_module_to_string - hugr_json = hugr.serialize() + hugr_json = hugr.to_json() llvm_module = compile_module_to_string(hugr_json) print(llvm_module) # noqa: T201 @@ -52,5 +53,5 @@ def guppy_to_circuit(guppy_func: RawFunctionDef) -> "Tk2Circuit": hugr = module.compile() assert hugr is not None, "Module must be compilable" - json = hugr.to_raw().to_json() + json = hugr.to_json() return Tk2Circuit.from_guppy_json(json, guppy_func.name)