From b701c0686900f890985099b9ea881b38ff0aaada Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 5 Dec 2023 09:28:33 +0000 Subject: [PATCH] refactor: Separate type checking from compilation code (#57) * Type checking code moved to guppy/checker (#58) * Graph generation code moved to guppy/compiler (#59) * Unified extension system with regular Guppy modules and added new buitlins module (#60) --- .github/workflows/pull-request.yaml | 2 +- guppy/__init__.py | 2 +- guppy/ast_util.py | 90 ++- guppy/cfg/bb.py | 95 +-- guppy/cfg/builder.py | 33 +- guppy/cfg/cfg.py | 323 +------- guppy/checker/__init__.py | 0 guppy/checker/cfg_checker.py | 223 ++++++ guppy/checker/core.py | 106 +++ guppy/checker/expr_checker.py | 364 +++++++++ guppy/checker/func_checker.py | 189 +++++ guppy/checker/stmt_checker.py | 140 ++++ guppy/compiler.py | 207 ----- guppy/compiler/__init__.py | 0 guppy/compiler/cfg_compiler.py | 173 +++++ guppy/compiler/core.py | 119 +++ guppy/compiler/expr_compiler.py | 116 +++ guppy/compiler/func_compiler.py | 113 +++ guppy/compiler/stmt_compiler.py | 96 +++ guppy/compiler_base.py | 251 ------ guppy/custom.py | 235 ++++++ guppy/declared.py | 66 ++ guppy/decorator.py | 195 +++++ guppy/error.py | 102 ++- guppy/expression.py | 310 -------- guppy/extension.py | 386 --------- guppy/function.py | 240 ------ guppy/{guppy_types.py => gtypes.py} | 114 ++- guppy/hugr/hugr.py | 26 +- guppy/module.py | 222 ++++++ guppy/nodes.py | 78 ++ guppy/prelude/_internal.py | 291 +++++++ guppy/prelude/boolean.py | 47 -- guppy/prelude/builtin.py | 230 ------ guppy/prelude/builtins.py | 731 ++++++++++++++++++ guppy/prelude/float.py | 260 ------- guppy/prelude/integer.py | 287 ------- guppy/prelude/quantum.py | 41 +- guppy/statement.py | 177 ----- tests/error/linear_errors/branch_use.py | 10 +- tests/error/linear_errors/break_unused.py | 10 +- tests/error/linear_errors/continue_unused.py | 10 +- tests/error/linear_errors/copy.py | 6 +- tests/error/linear_errors/if_both_unused.py | 8 +- .../linear_errors/if_both_unused_reassign.py | 8 +- tests/error/linear_errors/reassign_unused.py | 8 +- .../linear_errors/reassign_unused_tuple.py | 8 +- tests/error/linear_errors/unused.py | 6 +- tests/error/linear_errors/unused_expr.err | 7 + tests/error/linear_errors/unused_expr.py | 16 + .../error/linear_errors/unused_same_block.py | 6 +- .../nested_errors/different_types_if.err | 2 +- tests/error/nested_errors/linear_capture.py | 6 +- tests/error/type_errors/and_not_bool_left.py | 6 +- tests/error/type_errors/and_not_bool_right.py | 6 +- tests/error/type_errors/fun_ty_mismatch_1.err | 2 +- tests/error/type_errors/fun_ty_mismatch_2.err | 2 +- tests/error/type_errors/if_expr_not_bool.py | 6 +- tests/error/type_errors/if_not_bool.py | 6 +- tests/error/type_errors/not_not_bool.py | 6 +- tests/error/type_errors/or_not_bool_left.py | 6 +- tests/error/type_errors/or_not_bool_right.py | 6 +- tests/error/type_errors/return_mismatch.err | 2 +- tests/error/type_errors/while_not_bool.py | 6 +- tests/error/util.py | 24 +- tests/hugr/test_dummy_nodes.py | 18 +- tests/hugr/test_ports.py | 2 +- tests/integration/test_arithmetic.py | 4 +- tests/integration/test_basic.py | 9 +- tests/integration/test_call.py | 9 +- tests/integration/test_functional.py | 2 +- tests/integration/test_higher_order.py | 3 +- tests/integration/test_if.py | 3 +- tests/integration/test_linear.py | 49 +- tests/integration/test_nested.py | 2 +- tests/integration/test_programs.py | 5 +- tests/integration/test_unused.py | 2 +- tests/integration/test_while.py | 2 +- 78 files changed, 4011 insertions(+), 2968 deletions(-) create mode 100644 guppy/checker/__init__.py create mode 100644 guppy/checker/cfg_checker.py create mode 100644 guppy/checker/core.py create mode 100644 guppy/checker/expr_checker.py create mode 100644 guppy/checker/func_checker.py create mode 100644 guppy/checker/stmt_checker.py delete mode 100644 guppy/compiler.py create mode 100644 guppy/compiler/__init__.py create mode 100644 guppy/compiler/cfg_compiler.py create mode 100644 guppy/compiler/core.py create mode 100644 guppy/compiler/expr_compiler.py create mode 100644 guppy/compiler/func_compiler.py create mode 100644 guppy/compiler/stmt_compiler.py delete mode 100644 guppy/compiler_base.py create mode 100644 guppy/custom.py create mode 100644 guppy/declared.py create mode 100644 guppy/decorator.py delete mode 100644 guppy/expression.py delete mode 100644 guppy/extension.py delete mode 100644 guppy/function.py rename guppy/{guppy_types.py => gtypes.py} (63%) create mode 100644 guppy/module.py create mode 100644 guppy/nodes.py create mode 100644 guppy/prelude/_internal.py delete mode 100644 guppy/prelude/boolean.py delete mode 100644 guppy/prelude/builtin.py create mode 100644 guppy/prelude/builtins.py delete mode 100644 guppy/prelude/float.py delete mode 100644 guppy/prelude/integer.py delete mode 100644 guppy/statement.py create mode 100644 tests/error/linear_errors/unused_expr.err create mode 100644 tests/error/linear_errors/unused_expr.py diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 52f55df8..4ae52902 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - uses: actions/checkout@v3 diff --git a/guppy/__init__.py b/guppy/__init__.py index ff392293..bbe52f53 100644 --- a/guppy/__init__.py +++ b/guppy/__init__.py @@ -1 +1 @@ -__all__ = ["guppy_types"] +__all__ = ["types.py"] diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 356a28a7..8c4be4f7 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,6 +1,8 @@ import ast -from typing import Any, TypeVar, Generic, Union +from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING +if TYPE_CHECKING: + from guppy.gtypes import GuppyType AstNode = Union[ ast.AST, @@ -111,8 +113,92 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None: node.end_lineno = loc.end_lineno node.end_col_offset = loc.end_col_offset + source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc) + assert source is not None and file is not None and line_offset is not None + annotate_location(node, source, file, line_offset) -def is_empty_body(func_ast: ast.FunctionDef) -> bool: + +def annotate_location( + node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True +) -> None: + setattr(node, "line_offset", line_offset) + setattr(node, "file", file) + setattr(node, "source", source) + + if recurse: + for field, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + annotate_location(item, source, file, line_offset, recurse) + elif isinstance(value, ast.AST): + annotate_location(value, source, file, line_offset, recurse) + + +def get_file(node: AstNode) -> Optional[str]: + """Tries to retrieve a file annotation from an AST node.""" + try: + file = getattr(node, "file") + return file if isinstance(file, str) else None + except AttributeError: + return None + + +def get_source(node: AstNode) -> Optional[str]: + """Tries to retrieve a source annotation from an AST node.""" + try: + source = getattr(node, "source") + return source if isinstance(source, str) else None + except AttributeError: + return None + + +def get_line_offset(node: AstNode) -> Optional[int]: + """Tries to retrieve a line offset annotation from an AST node.""" + try: + line_offset = getattr(node, "line_offset") + return line_offset if isinstance(line_offset, int) else None + except AttributeError: + return None + + +A = TypeVar("A", bound=ast.AST) + + +def with_loc(loc: ast.AST, node: A) -> A: + """Copy source location from one AST node to the other.""" + set_location_from(node, loc) + return node + + +def with_type(ty: "GuppyType", node: A) -> A: + """Annotates an AST node with a type.""" + setattr(node, "type", ty) + return node + + +def get_type_opt(node: AstNode) -> Optional["GuppyType"]: + """Tries to retrieve a type annotation from an AST node.""" + from guppy.gtypes import GuppyType + + try: + ty = getattr(node, "type") + return ty if isinstance(ty, GuppyType) else None + except AttributeError: + return None + + +def get_type(node: AstNode) -> "GuppyType": + """Retrieve a type annotation from an AST node. + + Fails if the node is not annotated. + """ + ty = get_type_opt(node) + assert ty is not None + return ty + + +def has_empty_body(func_ast: ast.FunctionDef) -> bool: """Returns `True` if the body of a function definition is empty. This is the case if the body only contains a single `pass` statement or an ellipsis diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index dec4485a..95c9f63a 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -1,14 +1,14 @@ import ast +from abc import ABC from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING, Union, Any +from typing import Optional, TYPE_CHECKING, Union +from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.compiler_base import RawVariable, return_var -from guppy.guppy_types import FunctionType -from guppy.hugr.hugr import CFNode +from guppy.nodes import NestedFunctionDef if TYPE_CHECKING: - from guppy.cfg.cfg import CFG + from guppy.cfg.cfg import BaseCFG @dataclass @@ -34,61 +34,26 @@ def update_used(self, node: ast.AST) -> None: self.used[name.id] = name -VarRow = Sequence[RawVariable] - - -@dataclass(frozen=True) -class Signature: - """The signature of a basic block. - - Stores the inout/output variables with their types. - """ - - input_row: VarRow - output_rows: Sequence[VarRow] # One for each successor - - -@dataclass(frozen=True) -class CompiledBB: - """The result of compiling a basic block. - - Besides the corresponding node in the graph, we also store the signature of the - basic block with type information. - """ - - node: CFNode - bb: "BB" - sig: Signature - - -class NestedFunctionDef(ast.FunctionDef): - cfg: "CFG" - ty: FunctionType - - def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.cfg = cfg - self.ty = ty - - -BBStatement = Union[ast.Assign, ast.AugAssign, ast.Expr, ast.Return, NestedFunctionDef] +BBStatement = Union[ + ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef +] @dataclass(eq=False) # Disable equality to recover hash from `object` -class BB: +class BB(ABC): """A basic block in a control flow graph.""" idx: int # Pointer to the CFG that contains this node - cfg: "CFG" + containing_cfg: "BaseCFG[Self]" # AST statements contained in this BB statements: list[BBStatement] = field(default_factory=list) # Predecessor and successor BBs - predecessors: list["BB"] = field(default_factory=list) - successors: list["BB"] = field(default_factory=list) + predecessors: list[Self] = field(default_factory=list) + successors: list[Self] = field(default_factory=list) # If the BB has multiple successors, we need a predicate to decide to which one to # jump to @@ -107,13 +72,9 @@ def vars(self) -> VariableStats: assert self._vars is not None return self._vars - def compute_variable_stats(self, num_returns: int) -> None: - """Determines which variables are assigned/used in this BB. - - This also requires the expected number of returns of the whole CFG in order to - process `return` statements. - """ - visitor = VariableVisitor(self, num_returns) + def compute_variable_stats(self) -> None: + """Determines which variables are assigned/used in this BB.""" + visitor = VariableVisitor(self) for s in self.statements: visitor.visit(s) self._vars = visitor.stats @@ -121,26 +82,15 @@ def compute_variable_stats(self, num_returns: int) -> None: if self.branch_pred is not None: self._vars.update_used(self.branch_pred) - # In the `StatementCompiler`, we're going to turn return statements into - # assignments of dummy variables `%ret_xxx`. Thus, we have to register those - # variables as being used in the exit BB - if len(self.successors) == 0: - self._vars.used |= { - return_var(i): ast.Name(return_var(i), ast.Load) - for i in range(num_returns) - } - class VariableVisitor(ast.NodeVisitor): """Visitor that computes used and assigned variables in a BB.""" bb: BB stats: VariableStats - num_returns: int - def __init__(self, bb: BB, num_returns: int): + def __init__(self, bb: BB): self.bb = bb - self.num_returns = num_returns self.stats = VariableStats() def visit_Assign(self, node: ast.Assign) -> None: @@ -155,14 +105,11 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None: for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node - def visit_Return(self, node: ast.Return) -> None: - if node.value is not None: + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + if node.value: self.stats.update_used(node.value) - - # In the `StatementCompiler`, we're going to turn return statements into - # assignments of dummy variables `%ret_xxx`. To make the liveness analysis work, - # we have to register those variables as being assigned here - self.stats.assigned |= {return_var(i): node for i in range(self.num_returns)} + for name in name_nodes_in_ast(node.target): + self.stats.assigned[name.id] = node def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # In order to compute the used external variables in a nested function @@ -170,7 +117,7 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: from guppy.cfg.analysis import LivenessAnalysis for bb in node.cfg.bbs: - bb.compute_variable_stats(len(node.ty.returns)) + bb.compute_variable_stats() live = LivenessAnalysis().run(node.cfg.bbs) # Only store used *external* variables: things defined in the current BB, as diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 80e98559..e937e144 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,13 +1,14 @@ import ast import itertools -from typing import Optional, Iterator, Union, NamedTuple +from typing import Optional, Iterator, NamedTuple from guppy.ast_util import set_location_from, AstVisitor -from guppy.cfg.bb import BB, NestedFunctionDef +from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG -from guppy.compiler_base import Globals +from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError - +from guppy.gtypes import NoneType +from guppy.nodes import NestedFunctionDef # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -31,10 +32,9 @@ class CFGBuilder(AstVisitor[Optional[BB]]): """Constructs a CFG from ast nodes.""" cfg: CFG - num_returns: int globals: Globals - def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CFG: + def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG: """Builds a CFG from a list of ast nodes. We also require the expected number of return ports for the whole CFG. This is @@ -42,7 +42,6 @@ def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CF variables. """ self.cfg = CFG() - self.num_returns = num_returns self.globals = globals final_bb = self.visit_stmts( @@ -52,7 +51,7 @@ def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CF # If we're still in a basic block after compiling the whole body, we have to add # an implicit void return if final_bb is not None: - if num_returns > 0: + if not returns_none: raise GuppyError("Expected return statement", nodes[-1]) self.cfg.link(final_bb, self.cfg.exit_bb) @@ -76,15 +75,13 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[B bb_opt = self.visit(node, bb_opt, jumps) return bb_opt - def _build_node_value( - self, node: Union[ast.Assign, ast.AugAssign, ast.Return, ast.Expr], bb: BB - ) -> BB: + def _build_node_value(self, node: BBStatement, bb: BB) -> BB: """Utility method for building a node containing a `value` expression. Builds the expression and mutates `node.value` to point to the built expression. Returns the BB in which the expression is available and adds the node to it. """ - if node.value is not None: + if not isinstance(node, NestedFunctionDef) and node.value is not None: node.value, bb = ExprBuilder.build(node.value, self.cfg, bb) bb.statements.append(node) return bb @@ -97,6 +94,11 @@ def visit_AugAssign( ) -> Optional[BB]: return self._build_node_value(node, bb) + def visit_AnnAssign( + self, node: ast.AnnAssign, bb: BB, jumps: Jumps + ) -> Optional[BB]: + return self._build_node_value(node, bb) + def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]: # This is an expression statement where the value is discarded node.value, bb = ExprBuilder.build(node.value, self.cfg, bb) @@ -166,10 +168,11 @@ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> Optional[BB]: def visit_FunctionDef( self, node: ast.FunctionDef, bb: BB, jumps: Jumps ) -> Optional[BB]: - from guppy.function import FunctionDefCompiler + from guppy.checker.func_checker import check_signature - func_ty = FunctionDefCompiler.validate_signature(node, self.globals) - cfg = CFGBuilder().build(node.body, len(func_ty.returns), self.globals) + func_ty = check_signature(node, self.globals) + returns_none = isinstance(func_ty.returns, NoneType) + cfg = CFGBuilder().build(node.body, returns_none, self.globals) new_node = NestedFunctionDef( cfg, diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py index ef6f0314..2648f696 100644 --- a/guppy/cfg/cfg.py +++ b/guppy/cfg/cfg.py @@ -1,5 +1,4 @@ -import collections -from typing import Optional +from typing import Optional, TypeVar, Generic from guppy.cfg.analysis import ( LivenessDomain, @@ -9,41 +8,52 @@ MaybeAssignmentDomain, Result, ) -from guppy.cfg.bb import ( - BB, - VarRow, - Signature, - CompiledBB, - BBStatement, -) -from guppy.compiler_base import DFContainer, Variable, Globals, is_return_var -from guppy.error import GuppyError, GuppyTypeError -from guppy.ast_util import line_col -from guppy.expression import ExpressionCompiler -from guppy.guppy_types import GuppyType, TupleType, SumType -from guppy.hugr.hugr import Node, Hugr, OutPortV -from guppy.statement import StatementCompiler +from guppy.cfg.bb import BB, BBStatement + +T = TypeVar("T", bound=BB) -class CFG: - """A control-flow graph of basic blocks.""" - bbs: list[BB] - entry_bb: BB - exit_bb: BB +class BaseCFG(Generic[T]): + """Abstract base class for control-flow graphs.""" + + bbs: list[T] + entry_bb: T + exit_bb: T live_before: Result[LivenessDomain] ass_before: Result[DefAssignmentDomain] maybe_ass_before: Result[MaybeAssignmentDomain] - def __init__(self) -> None: - self.bbs = [] - self.entry_bb = self.new_bb() - self.exit_bb = self.new_bb() + def __init__( + self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None + ): + self.bbs = bbs + if entry_bb: + self.entry_bb = entry_bb + if exit_bb: + self.exit_bb = exit_bb self.live_before = {} self.ass_before = {} self.maybe_ass_before = {} + def analyze(self, def_ass_before: set[str], maybe_ass_before: set[str]) -> None: + for bb in self.bbs: + bb.compute_variable_stats() + self.live_before = LivenessAnalysis().run(self.bbs) + self.ass_before, self.maybe_ass_before = AssignmentAnalysis( + self.bbs, def_ass_before, maybe_ass_before + ).run_unpacked(self.bbs) + + +class CFG(BaseCFG[BB]): + """A control-flow graph of unchecked basic blocks.""" + + def __init__(self) -> None: + super().__init__([]) + self.entry_bb = self.new_bb() + self.exit_bb = self.new_bb() + def new_bb(self, *preds: BB, statements: Optional[list[BBStatement]] = None) -> BB: """Adds a new basic block to the CFG.""" bb = BB( @@ -58,266 +68,3 @@ def link(self, src_bb: BB, tgt_bb: BB) -> None: """Adds a control-flow edge between two basic blocks.""" src_bb.successors.append(tgt_bb) tgt_bb.predecessors.append(src_bb) - - def analyze( - self, num_returns: int, def_ass_before: set[str], maybe_ass_before: set[str] - ) -> None: - for bb in self.bbs: - bb.compute_variable_stats(num_returns) - self.live_before = LivenessAnalysis().run(self.bbs) - self.ass_before, self.maybe_ass_before = AssignmentAnalysis( - self.bbs, def_ass_before, maybe_ass_before - ).run_unpacked(self.bbs) - - def compile( - self, - graph: Hugr, - input_row: VarRow, - return_tys: list[GuppyType], - parent: Node, - globals: Globals, - ) -> None: - """Compiles the CFG.""" - - # First, we need to run program analysis - ass_before = {v.name for v in input_row} - self.analyze(len(return_tys), ass_before, ass_before) - - # We start by compiling the entry BB - entry_compiled = self._compile_bb( - self.entry_bb, input_row, return_tys, graph, parent, globals - ) - compiled = {self.entry_bb: entry_compiled} - - # Visit all control-flow edges in BFS order. We can't just do a normal loop over - # all BBs since the input types for a BB are computed by compiling a predecessor - queue = collections.deque( - (entry_compiled, i, succ) for i, succ in enumerate(self.entry_bb.successors) - ) - while len(queue) > 0: - pred, num_output, bb = queue.popleft() - out_row = pred.sig.output_rows[num_output] - - if bb in compiled: - # If the BB was already compiled, we just have to check that the - # signatures match. - self._check_rows_match(out_row, compiled[bb].sig.input_row, bb) - else: - # Otherwise, compile the BB and enqueue its successors - compiled_bb = self._compile_bb( - bb, out_row, return_tys, graph, parent, globals - ) - queue += [ - (compiled_bb, i, succ) for i, succ in enumerate(bb.successors) - ] - compiled[bb] = compiled_bb - - graph.add_edge( - pred.node.out_port(num_output), compiled[bb].node.in_port(None) - ) - - def _compile_bb( - self, - bb: BB, - input_row: VarRow, - return_tys: list[GuppyType], - graph: Hugr, - parent: Node, - globals: Globals, - ) -> CompiledBB: - """Compiles a single basic block.""" - - # The exit BB is completely empty - if len(bb.successors) == 0: - block = graph.add_exit(return_tys, parent) - return CompiledBB(block, bb, Signature(input_row, [])) - - # For the entry BB we have to separately check that all used variables are - # defined. For all other BBs, this will be checked when compiling a predecessor. - if len(bb.predecessors) == 0: - for x, use in bb.vars.used.items(): - if x not in self.ass_before[bb] and x not in globals.values: - raise GuppyError(f"Variable `{x}` is not defined", use) - - # Compile the basic block - block = graph.add_block(parent, num_successors=len(bb.successors)) - inp = graph.add_input(output_tys=[v.ty for v in input_row], parent=block) - dfg = DFContainer( - block, - { - v.name: Variable(v.name, inp.out_port(i), v.defined_at) - for (i, v) in enumerate(input_row) - }, - ) - stmt_compiler = StatementCompiler(graph, globals) - dfg = stmt_compiler.compile_stmts(bb.statements, bb, dfg, return_tys) - - # If we branch, we also have to compile the branch predicate - if len(bb.successors) > 1: - assert bb.branch_pred is not None - expr_compiler = ExpressionCompiler(graph, globals) - port = expr_compiler.compile(bb.branch_pred, dfg) - func = globals.get_instance_func(port.ty, "__bool__") - if func is None: - raise GuppyTypeError( - f"Expression of type `{port.ty}` cannot be interpreted as a `bool`", - bb.branch_pred, - ) - [branch_port] = func.compile_call( - [port], dfg, graph, globals, bb.branch_pred - ) - - for succ in bb.successors: - for x, use_bb in self.live_before[succ].items(): - # Check that the variable requested by the successor are defined - if x not in dfg and x not in globals.values: - # If the variable is defined on *some* paths, we can give a more - # informative error message - if x in self.maybe_ass_before[use_bb]: - # TODO: This should be "Variable x is not defined when coming - # from {bb}". But for this we need a way to associate BBs with - # source locations. - raise GuppyError( - f"Variable `{x}` is not defined on all control-flow paths.", - use_bb.vars.used[x], - ) - raise GuppyError( - f"Variable `{x}` is not defined", use_bb.vars.used[x] - ) - - # We have to check that used linear variables are not being outputted - if x in dfg: - var = dfg[x] - if var.ty.linear and var.used: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - self.live_before[succ][x].vars.used[x], - [var.used], - ) - - # On the other hand, unused linear variables *must* be outputted - for x, var in dfg.variables.items(): - if var.ty.linear and not var.used and x not in self.live_before[succ]: - # TODO: This should be "Variable x with linear type ty is not - # used in {bb}". But for this we need a way to associate BBs with - # source locations. - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is " - "not used on all control-flow paths", - var.defined_at, - ) - - # Finally, we have to add the block output. The easy case is if we don't branch: - # We just output the variables that are live in the successor - output_vars = sorted( - dfg[x] for x in self.live_before[bb.successors[0]] if x in dfg - ) - if len(bb.successors) == 1: - # Even if we don't branch, we still have to add a `Sum(())` predicate - unit = graph.add_make_tuple([], parent=block).out_port(0) - branch_port = graph.add_tag( - variants=[TupleType([])], tag=0, inp=unit, parent=block - ).out_port(0) - else: - # If we branch and the branches use different variables, we have to output a - # Sum-type predicate - first, *rest = bb.successors - if any( - self.live_before[r].keys() & dfg.variables.keys() - != self.live_before[first].keys() & dfg.variables.keys() - for r in rest - ): - # We put all non-linear variables into the branch predicate and all - # linear variables in the normal output (since they are shared between - # all successors). This is in line with the definition of `<` on - # variables which puts linear variables at the end. The only exception - # are return vars which must be outputted in order. - branch_port = self._choose_vars_for_pred( - graph=graph, - pred=branch_port, - output_vars=[ - sorted( - x - for x in self.live_before[succ] - if x in dfg and (not dfg[x].ty.linear or is_return_var(x)) - ) - for succ in bb.successors - ], - dfg=dfg, - ) - output_vars = sorted( - dfg[x] - # We can look at `successors[0]` here since all successors must have - # the same `live_before` linear variables - for x in self.live_before[bb.successors[0]] - if x in dfg and dfg[x].ty.linear and not is_return_var(x) - ) - - graph.add_output( - inputs=[branch_port] + [v.port for v in output_vars], parent=block - ) - output_rows = [ - sorted([dfg[x] for x in self.live_before[succ] if x in dfg]) - for succ in bb.successors - ] - - return CompiledBB(block, bb, Signature(input_row, output_rows)) - - def _check_rows_match(self, row1: VarRow, row2: VarRow, bb: BB) -> None: - """Checks that the types of two rows match up. - - Otherwise, an error is thrown, alerting the user that a variable has different - types on different control-flow paths. - """ - assert len(row1) == len(row2) - for v1, v2 in zip(row1, row2): - assert v1.name == v2.name - if v1.ty != v2.ty: - # In the error message, we want to mention the variable that was first - # defined at the start. - if ( - v1.defined_at - and v2.defined_at - and line_col(v2.defined_at) < line_col(v1.defined_at) - ): - v1, v2 = v2, v1 - # We shouldn't mention temporary variables (starting with `%`) - # in error messages: - ident = ( - "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" - ) - raise GuppyError( - f"{ident} can refer to different types: " - f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", - self.live_before[bb][v1.name].vars.used[v1.name], - [v1.defined_at, v2.defined_at], - ) - - @staticmethod - def _choose_vars_for_pred( - graph: Hugr, pred: OutPortV, output_vars: list[list[str]], dfg: DFContainer - ) -> OutPortV: - """Selects an output based on a predicate. - - Given `pred: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, - constructs a predicate value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. - """ - assert isinstance(pred.ty, SumType) - assert len(pred.ty.element_types) == len(output_vars) - tuples = [ - graph.add_make_tuple( - inputs=[dfg[x].port for x in sorted(vs) if x in dfg], parent=dfg.node - ).out_port(0) - for vs in output_vars - ] - tys = [t.ty for t in tuples] - conditional = graph.add_conditional( - cond_input=pred, inputs=tuples, parent=dfg.node - ) - for i, ty in enumerate(tys): - case = graph.add_case(conditional) - inp = graph.add_input(output_tys=tys, parent=case).out_port(i) - tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) - graph.add_output(inputs=[tag], parent=case) - return conditional.add_out_port(SumType(tys)) diff --git a/guppy/checker/__init__.py b/guppy/checker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py new file mode 100644 index 00000000..787b2ccb --- /dev/null +++ b/guppy/checker/cfg_checker.py @@ -0,0 +1,223 @@ +"""Type checking code for control-flow graphs + +Operates on CFGs produced by the `CFGBuilder`. Produces a `CheckedCFG` consisting of +`CheckedBB`s with inferred type signatures. +""" + +import collections +from dataclasses import dataclass +from typing import Sequence + +from guppy.ast_util import line_col +from guppy.cfg.bb import BB +from guppy.cfg.cfg import CFG, BaseCFG +from guppy.checker.core import Globals, Context + +from guppy.checker.core import Variable +from guppy.checker.expr_checker import ExprSynthesizer, to_bool +from guppy.checker.stmt_checker import StmtChecker +from guppy.error import GuppyError +from guppy.gtypes import GuppyType + + +VarRow = Sequence[Variable] + + +@dataclass(frozen=True) +class Signature: + """The signature of a basic block. + + Stores the input/output variables with their types. + """ + + input_row: VarRow + output_rows: Sequence[VarRow] # One for each successor + + @staticmethod + def empty() -> "Signature": + return Signature([], []) + + +@dataclass(eq=False) # Disable equality to recover hash from `object` +class CheckedBB(BB): + """Basic block annotated with an input and output type signature.""" + + sig: Signature = Signature.empty() + + +class CheckedCFG(BaseCFG[CheckedBB]): + input_tys: list[GuppyType] + output_ty: GuppyType + + def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None: + super().__init__([]) + self.input_tys = input_tys + self.output_ty = output_ty + + +def check_cfg( + cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals +) -> CheckedCFG: + """Type checks a control-flow graph. + + Annotates the basic blocks with input and output type signatures and removes + unreachable blocks. + """ + # First, we need to run program analysis + ass_before = set(v.name for v in inputs) + cfg.analyze(ass_before, ass_before) + + # We start by compiling the entry BB + checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) + checked_cfg.entry_bb = check_bb( + cfg.entry_bb, checked_cfg, inputs, return_ty, globals + ) + compiled = {cfg.entry_bb: checked_cfg.entry_bb} + + # Visit all control-flow edges in BFS order. We can't just do a normal loop over + # all BBs since the input types for a BB are computed by checking a predecessor. + # We do BFS instead of DFS to get a better error ordering. + queue = collections.deque( + (checked_cfg.entry_bb, i, succ) + for i, succ in enumerate(cfg.entry_bb.successors) + ) + while len(queue) > 0: + pred, num_output, bb = queue.popleft() + input_row = [ + Variable(v.name, v.ty, v.defined_at, None) + for v in pred.sig.output_rows[num_output] + ] + + if bb in compiled: + # If the BB was already compiled, we just have to check that the signatures + # match. + check_rows_match(input_row, compiled[bb].sig.input_row, bb) + else: + # Otherwise, check the BB and enqueue its successors + checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) + queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)] + compiled[bb] = checked_bb + + # Link up BBs in the checked CFG + compiled[bb].predecessors.append(pred) + pred.successors[num_output] = compiled[bb] + + checked_cfg.bbs = list(compiled.values()) + checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable + checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} + checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} + checked_cfg.maybe_ass_before = { + compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs + } + return checked_cfg + + +def check_bb( + bb: BB, + checked_cfg: CheckedCFG, + inputs: VarRow, + return_ty: GuppyType, + globals: Globals, +) -> CheckedBB: + cfg = bb.containing_cfg + + # For the entry BB we have to separately check that all used variables are + # defined. For all other BBs, this will be checked when compiling a predecessor. + if bb == cfg.entry_bb: + assert len(bb.predecessors) == 0 + for x, use in bb.vars.used.items(): + if x not in cfg.ass_before[bb] and x not in globals.values: + raise GuppyError(f"Variable `{x}` is not defined", use) + + # Check the basic block + ctx = Context(globals, {v.name: v for v in inputs}) + checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) + + # If we branch, we also have to check the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + bb.branch_pred, ty = ExprSynthesizer(ctx).synthesize(bb.branch_pred) + bb.branch_pred, _ = to_bool(bb.branch_pred, ty, ctx) + + for succ in bb.successors: + for x, use_bb in cfg.live_before[succ].items(): + # Check that the variables requested by the successor are defined + if x not in ctx.locals and x not in ctx.globals.values: + # If the variable is defined on *some* paths, we can give a more + # informative error message + if x in cfg.maybe_ass_before[use_bb]: + # TODO: This should be "Variable x is not defined when coming + # from {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` is not defined on all control-flow paths.", + use_bb.vars.used[x], + ) + raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x]) + + # We have to check that used linear variables are not being outputted + if x in ctx.locals: + var = ctx.locals[x] + if var.ty.linear and var.used: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + cfg.live_before[succ][x].vars.used[x], + [var.used], + ) + + # On the other hand, unused linear variables *must* be outputted + for x, var in ctx.locals.items(): + if var.ty.linear and not var.used and x not in cfg.live_before[succ]: + # TODO: This should be "Variable x with linear type ty is not + # used in {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is " + "not used on all control-flow paths", + var.defined_at, + ) + + # Finally, we need to compute the signature of the basic block + outputs = [ + [ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals] + for succ in bb.successors + ] + + # Also prepare the successor list so we can fill it in later + checked_bb = CheckedBB( + bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) + ) + checked_bb.successors = [None] * len(bb.successors) # type: ignore[list-item] + checked_bb.branch_pred = bb.branch_pred + return checked_bb + + +def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: + """Checks that the types of two rows match up. + + Otherwise, an error is thrown, alerting the user that a variable has different + types on different control-flow paths. + """ + map1, map2 = {v.name: v for v in row1}, {v.name: v for v in row2} + assert map1.keys() == map2.keys() + for x in map1: + v1, v2 = map1[x], map2[x] + if v1.ty != v2.ty: + # In the error message, we want to mention the variable that was first + # defined at the start. + if ( + v1.defined_at + and v2.defined_at + and line_col(v2.defined_at) < line_col(v1.defined_at) + ): + v1, v2 = v2, v1 + # We shouldn't mention temporary variables (starting with `%`) + # in error messages: + ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" + raise GuppyError( + f"{ident} can refer to different types: " + f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", + bb.containing_cfg.live_before[bb][v1.name].vars.used[v1.name], + [v1.defined_at, v2.defined_at], + ) diff --git a/guppy/checker/core.py b/guppy/checker/core.py new file mode 100644 index 00000000..9bab6375 --- /dev/null +++ b/guppy/checker/core.py @@ -0,0 +1,106 @@ +import ast +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import NamedTuple, Optional, Union + +from guppy.ast_util import AstNode +from guppy.gtypes import ( + GuppyType, + FunctionType, + TupleType, + SumType, + NoneType, + BoolType, +) + + +@dataclass +class Variable: + """Class holding data associated with a variable.""" + + name: str + ty: GuppyType + defined_at: Optional[AstNode] + used: Optional[AstNode] + + +@dataclass +class CallableVariable(ABC, Variable): + """Abstract base class for global variables that can be called.""" + + ty: FunctionType + + @abstractmethod + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" + ) -> ast.expr: + """Checks the return type of a function call against a given type.""" + + @abstractmethod + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: "Context" + ) -> tuple[ast.expr, GuppyType]: + """Synthesizes the return type of a function call.""" + + +class Globals(NamedTuple): + """Collection of names that are available on module-level. + + Separately stores names that are bound to values (i.e. module-level functions or + constants), to types, or to instance functions belonging to types. + """ + + values: dict[str, Variable] + types: dict[str, type[GuppyType]] + + @staticmethod + def default() -> "Globals": + """Generates a `Globals` instance that is populated with all core types""" + tys: dict[str, type[GuppyType]] = { + FunctionType.name: FunctionType, + TupleType.name: TupleType, + SumType.name: SumType, + NoneType.name: NoneType, + BoolType.name: BoolType, + } + return Globals({}, tys) + + def get_instance_func(self, ty: GuppyType, name: str) -> Optional[CallableVariable]: + """Looks up an instance function with a given name for a type. + + Returns `None` if the name doesn't exist or isn't a function. + """ + qualname = qualified_name(ty.__class__, name) + if qualname in self.values: + val = self.values[qualname] + if isinstance(val, CallableVariable): + return val + return None + + def __or__(self, other: "Globals") -> "Globals": + return Globals( + self.values | other.values, + self.types | other.types, + ) + + def __ior__(self, other: "Globals") -> "Globals": + self.values.update(other.values) + self.types.update(other.types) + return self + + +# Local variable mapping +Locals = dict[str, Variable] + + +class Context(NamedTuple): + """The type checking context.""" + + globals: Globals + locals: Locals + + +def qualified_name(ty: Union[type[GuppyType], str], name: str) -> str: + """Returns a qualified name for an instance function on a type.""" + ty_name = ty if isinstance(ty, str) else ty.name + return f"{ty_name}.{name}" diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py new file mode 100644 index 00000000..09c12774 --- /dev/null +++ b/guppy/checker/expr_checker.py @@ -0,0 +1,364 @@ +"""Type checking and synthesizing code for expressions. + +Operates on expressions in a basic block after CFG construction. In particular, we +assume that expressions that involve control flow (i.e. short-circuiting and ternary +expressions) have been removed during CFG construction. + +Furthermore, we assume that assignment expressions with the walrus operator := have +been turned into regular assignments and are no longer present. As a result, expressions +are assumed to be side effect free, in the sense that they do not modify the variables +available in the type checking context. + +We may alter/desugar AST nodes during type checking. In particular, we turn `ast.Name` +nodes into either `LocalName` or `GlobalName` nodes and `ast.Call` nodes are turned into +`LocalCall` or `GlobalCall` nodes. Furthermore, all nodes in the resulting AST are +annotated with their type. + +Expressions can be checked against a given type by the `ExprChecker`, raising a type +error if the expressions doesn't have the expected type. Checking is used for annotated +assignments, return values, and function arguments. Alternatively, the `ExprSynthesizer` +can be used to infer a type for an expression. +""" + +import ast +from contextlib import suppress +from typing import Optional, Union, NoReturn, Any + +from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt +from guppy.checker.core import Context, CallableVariable, Globals +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType +from guppy.nodes import LocalName, GlobalName, LocalCall + +# Mapping from unary AST op to dunder method and display name +unary_table: dict[type[ast.unaryop], tuple[str, str]] = { + ast.UAdd: ("__pos__", "+"), + ast.USub: ("__neg__", "-"), + ast.Invert: ("__invert__", "~"), +} # fmt: skip + +# Mapping from binary AST op to left dunder method, right dunder method and display name +AstOp = Union[ast.operator, ast.cmpop] +binary_table: dict[type[AstOp], tuple[str, str, str]] = { + ast.Add: ("__add__", "__radd__", "+"), + ast.Sub: ("__sub__", "__rsub__", "-"), + ast.Mult: ("__mul__", "__rmul__", "*"), + ast.Div: ("__truediv__", "__rtruediv__", "/"), + ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), + ast.Mod: ("__mod__", "__rmod__", "%"), + ast.Pow: ("__pow__", "__rpow__", "**"), + ast.LShift: ("__lshift__", "__rlshift__", "<<"), + ast.RShift: ("__rshift__", "__rrshift__", ">>"), + ast.BitOr: ("__or__", "__ror__", "|"), + ast.BitXor: ("__xor__", "__rxor__", "^"), + ast.BitAnd: ("__and__", "__rand__", "&"), + ast.MatMult: ("__matmul__", "__rmatmul__", "@"), + ast.Eq: ("__eq__", "__eq__", "=="), + ast.NotEq: ("__neq__", "__neq__", "!="), + ast.Lt: ("__lt__", "__gt__", "<"), + ast.LtE: ("__le__", "__ge__", "<="), + ast.Gt: ("__gt__", "__lt__", ">"), + ast.GtE: ("__ge__", "__le__", ">="), +} # fmt: skip + + +class ExprChecker(AstVisitor[ast.expr]): + """Checks an expression against a type and produces a new type-annotated AST""" + + ctx: Context + + # Name for the kind of term we are currently checking against (used in errors). + # For example, "argument", "return value", or in general "expression". + _kind: str + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + self._kind = "expression" + + def _fail( + self, + expected: GuppyType, + actual: Union[ast.expr, GuppyType], + loc: Optional[AstNode] = None, + ) -> NoReturn: + """Raises a type error indicating that the type doesn't match.""" + if not isinstance(actual, GuppyType): + loc = loc or actual + _, actual = self._synthesize(actual) + if loc is None: + raise InternalGuppyError("Failure location is required") + raise GuppyTypeError( + f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc + ) + + def check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + """Checks an expression against a type. + + Returns a new desugared expression with type annotations. + """ + old_kind = self._kind + self._kind = kind or self._kind + expr = self.visit(expr, ty) + self._kind = old_kind + return with_type(ty, expr) + + def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Invokes the type synthesiser""" + return ExprSynthesizer(self.ctx).synthesize(node) + + def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: + if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): + return self._fail(ty, node) + for i, el in enumerate(node.elts): + node.elts[i] = self.check(el, ty.element_types[i]) + return node + + def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore[override] + # Try to synthesize and then check if it matches the given type + node, synth = self._synthesize(node) + if synth != ty: + self._fail(ty, synth, node) + return node + + +class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): + ctx: Context + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + + def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Tries to synthesise a type for the given expression. + + Also returns a new desugared expression with type annotations. + """ + if ty := get_type_opt(node): + return node, ty + node, ty = self.visit(node) + return with_type(ty, node), ty + + def _check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + """Checks an expression against a given type""" + return ExprChecker(self.ctx).check(expr, ty, kind) + + def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: + ty = python_value_to_guppy_type(node.value, node, self.ctx.globals) + if ty is None: + raise GuppyError("Unsupported constant", node) + return node, ty + + def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + x = node.id + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is not None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + node, + [var.used], + ) + var.used = node + return with_loc(node, LocalName(id=x)), var.ty + elif x in self.ctx.globals.values: + # Cache value in the AST + value = self.ctx.globals.values[x] + return with_loc(node, GlobalName(id=x, value=value)), value.ty + raise InternalGuppyError( + f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " + f"been caught by program analysis!" + ) + + def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: + elems = [self.synthesize(elem) for elem in node.elts] + node.elts = [n for n, _ in elems] + return node, TupleType([ty for _, ty in elems]) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: + # We need to synthesise the argument type, so we can look up dunder methods + node.operand, op_ty = self.synthesize(node.operand) + + # Special case for the `not` operation since it is not implemented via a dunder + # method or control-flow + if isinstance(node.op, ast.Not): + node.operand, bool_ty = to_bool(node.operand, op_ty, self.ctx) + return node, bool_ty + + # Check all other unary expressions by calling out to instance dunder methods + op, display_name = unary_table[node.op.__class__] + func = self.ctx.globals.get_instance_func(op_ty, op) + if func is None: + raise GuppyTypeError( + f"Unary operator `{display_name}` not defined for argument of type " + f" `{op_ty}`", + node.operand, + ) + return func.synthesize_call([node.operand], node, self.ctx) + + def _synthesize_binary( + self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr + ) -> tuple[ast.expr, GuppyType]: + """Helper method to compile binary operators by calling out to dunder methods. + + For example, first try calling `__add__` on the left operand. If that fails, try + `__radd__` on the right operand. + """ + if op.__class__ not in binary_table: + raise GuppyError("This binary operation is not supported by Guppy.", op) + lop, rop, display_name = binary_table[op.__class__] + left_expr, left_ty = self.synthesize(left_expr) + right_expr, right_ty = self.synthesize(right_expr) + + if func := self.ctx.globals.get_instance_func(left_ty, lop): + with suppress(GuppyError): + return func.synthesize_call([left_expr, right_expr], node, self.ctx) + + if func := self.ctx.globals.get_instance_func(right_ty, rop): + with suppress(GuppyError): + return func.synthesize_call([right_expr, left_expr], node, self.ctx) + + raise GuppyTypeError( + f"Binary operator `{display_name}` not defined for arguments of type " + f"`{left_ty}` and `{right_ty}`", + node, + ) + + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: + return self._synthesize_binary(node.left, node.right, node.op, node) + + def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: + if len(node.comparators) != 1 or len(node.ops) != 1: + raise InternalGuppyError( + "BB contains chained comparison. Should have been removed during CFG " + "construction." + ) + left_expr, [op], [right_expr] = node.left, node.ops, node.comparators + return self._synthesize_binary(left_expr, right_expr, op, node) + + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: + if len(node.keywords) > 0: + raise GuppyError("Keyword arguments are not supported", node.keywords[0]) + node.func, ty = self.synthesize(node.func) + + # First handle direct calls of user-defined functions and extension functions + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): + return node.func.value.synthesize_call(node.args, node, self.ctx) + + # Otherwise, it must be a function as a higher-order value + if isinstance(ty, FunctionType): + args, return_ty = synthesize_call(ty, node.args, node, self.ctx) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif f := self.ctx.globals.get_instance_func(ty, "__call__"): + return f.synthesize_call(node.args, node, self.ctx) + else: + raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `NamedExpr`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + + def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `BoolOp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `IfExp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + +def check_num_args(exp: int, act: int, node: AstNode) -> None: + """Checks that the correct number of arguments have been passed to a function.""" + if act < exp: + raise GuppyTypeError( + f"Not enough arguments passed (expected {exp}, got {act})", node + ) + if exp < act: + if isinstance(node, ast.Call): + raise GuppyTypeError("Unexpected argument", node.args[exp]) + raise GuppyTypeError( + f"Too many arguments passed (expected {exp}, got {act})", node + ) + + +def synthesize_call( + func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context +) -> tuple[list[ast.expr], GuppyType]: + """Synthesizes the return type of a function call. + + Also returns desugared versions of the arguments with type annotations. + """ + check_num_args(len(func_ty.args), len(args), node) + for i, arg in enumerate(args): + args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") + return args, func_ty.returns + + +def check_call( + func_ty: FunctionType, + args: list[ast.expr], + ty: GuppyType, + node: AstNode, + ctx: Context, +) -> list[ast.expr]: + """Checks the return type of a function call against a given type""" + args, return_ty = synthesize_call(func_ty, args, node, ctx) + if return_ty != ty: + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `{return_ty}`", node + ) + return args + + +def to_bool( + node: ast.expr, node_ty: GuppyType, ctx: Context +) -> tuple[ast.expr, GuppyType]: + """Tries to turn a node into a bool""" + if isinstance(node_ty, BoolType): + return node, node_ty + + func = ctx.globals.get_instance_func(node_ty, "__bool__") + if func is None: + raise GuppyTypeError( + f"Expression of type `{node_ty}` cannot be interpreted as a `bool`", + node, + ) + + # We could check the return type against bool, but we can give a better error + # message if we synthesise and compare to bool by hand + call, return_ty = func.synthesize_call([node], node, ctx) + if not isinstance(return_ty, BoolType): + raise GuppyTypeError( + f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", + node, + ) + return call, return_ty + + +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals +) -> Optional[GuppyType]: + """Turns a primitive Python value into a Guppy type. + + Returns `None` if the Python value cannot be represented in Guppy. + """ + match v: + case bool(): + return globals.types["bool"].build(node=node) + case int(): + return globals.types["int"].build(node=node) + case float(): + return globals.types["float"].build(node=node) + case _: + return None diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py new file mode 100644 index 00000000..e0695824 --- /dev/null +++ b/guppy/checker/func_checker.py @@ -0,0 +1,189 @@ +"""Type checking code for top-level and nested function definitions. + +For top-level functions, we take the `DefinedFunction` containing the `ast.FunctionDef` +node straight from the Python AST. We build a CFG, check it, and return a +`CheckedFunction` containing a `CheckedCFG` with type annotations. +""" + +import ast +from dataclasses import dataclass + +from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc +from guppy.cfg.bb import BB +from guppy.cfg.builder import CFGBuilder +from guppy.checker.core import Variable, Globals, Context, CallableVariable +from guppy.checker.cfg_checker import check_cfg, CheckedCFG +from guppy.checker.expr_checker import synthesize_call, check_call +from guppy.error import GuppyError +from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType +from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef + + +@dataclass +class DefinedFunction(CallableVariable): + """A user-defined function""" + + ty: FunctionType + defined_at: ast.FunctionDef + + @staticmethod + def from_ast( + func_def: ast.FunctionDef, name: str, globals: Globals + ) -> "DefinedFunction": + ty = check_signature(func_def, globals) + return DefinedFunction(name, ty, func_def, None) + + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> GlobalCall: + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) + + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[GlobalCall, GuppyType]: + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty + + +@dataclass +class CheckedFunction(DefinedFunction): + """Type checked version of a user-defined function""" + + cfg: CheckedCFG + + +def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFunction: + """Type checks a top-level function definition.""" + func_def = func.defined_at + args = func_def.args.args + returns_none = isinstance(func.ty.returns, NoneType) + assert func.ty.arg_names is not None + + cfg = CFGBuilder().build(func_def.body, returns_none, globals) + inputs = [ + Variable(x, ty, loc, None) + for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) + ] + cfg = check_cfg(cfg, inputs, func.ty.returns, globals) + return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) + + +def check_nested_func_def( + func_def: NestedFunctionDef, bb: BB, ctx: Context +) -> CheckedNestedFunctionDef: + """Type checks a local (nested) function definition.""" + func_ty = check_signature(func_def, ctx.globals) + assert func_ty.arg_names is not None + + # We've already built the CFG for this function while building the CFG of the + # enclosing function + cfg = func_def.cfg + + # Find captured variables + parent_cfg = bb.containing_cfg + def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() + maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] + cfg.analyze(def_ass_before, maybe_ass_before) + captured = { + x: ctx.locals[x] + for x in cfg.live_before[cfg.entry_bb] + if x not in func_ty.arg_names and x in ctx.locals + } + + # Captured variables may not be linear + for v in captured.values(): + if v.ty.linear: + x = v.name + using_bb = cfg.live_before[cfg.entry_bb][x] + raise GuppyError( + f"Variable `{x}` with linear type `{v.ty}` may not be used here " + f"because it was defined in an outer scope (at {{0}})", + using_bb.vars.used[x], + [v.defined_at], + ) + + # Captured variables may never be assigned to + for bb in cfg.bbs: + for v in captured.values(): + x = v.name + if x in bb.vars.assigned: + raise GuppyError( + f"Variable `{x}` defined in an outer scope (at {{0}}) may not " + f"be assigned to", + bb.vars.assigned[x], + [v.defined_at], + ) + + # Construct inputs for checking the body CFG + inputs = list(captured.values()) + [ + Variable(x, ty, func_def.args.args[i], None) + for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) + ] + globals = ctx.globals + + # Check if the body contains a free (recursive) occurrence of the function name. + # By checking if the name is free at the entry BB, we avoid false positives when + # a user shadows the name with a local variable + if func_def.name in cfg.live_before[cfg.entry_bb]: + if not captured: + # If there are no captured vars, we treat the function like a global name + func = DefinedFunction(func_def.name, func_ty, func_def, None) + globals = ctx.globals | Globals({func_def.name: func}, {}) + + else: + # Otherwise, we treat it like a local name + inputs.append(Variable(func_def.name, func_def.ty, func_def, None)) + + checked_cfg = check_cfg(cfg, inputs, func_ty.returns, globals) + checked_def = CheckedNestedFunctionDef( + checked_cfg, + func_ty, + captured, + name=func_def.name, + args=func_def.args, + body=func_def.body, + decorator_list=func_def.decorator_list, + returns=func_def.returns, + type_comment=func_def.type_comment, + ) + return with_loc(func_def, checked_def) + + +def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: + """Checks the signature of a function definition and returns the corresponding + Guppy type.""" + if len(func_def.args.posonlyargs) != 0: + raise GuppyError( + "Positional-only parameters not supported", func_def.args.posonlyargs[0] + ) + if len(func_def.args.kwonlyargs) != 0: + raise GuppyError( + "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] + ) + if func_def.args.vararg is not None: + raise GuppyError("*args not supported", func_def.args.vararg) + if func_def.args.kwarg is not None: + raise GuppyError("**kwargs not supported", func_def.args.kwarg) + if func_def.returns is None: + # TODO: Error location is incorrect + if all(r.value is None for r in return_nodes_in_ast(func_def)): + raise GuppyError( + "Return type must be annotated. Try adding a `-> None` annotation.", + func_def, + ) + raise GuppyError("Return type must be annotated", func_def) + + arg_tys = [] + arg_names = [] + for i, arg in enumerate(func_def.args.args): + if arg.annotation is None: + raise GuppyError("Argument type must be annotated", arg) + ty = type_from_ast(arg.annotation, globals) + arg_tys.append(ty) + arg_names.append(arg.arg) + + ret_type = type_from_ast(func_def.returns, globals) + return FunctionType(arg_tys, ret_type, arg_names) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py new file mode 100644 index 00000000..5622ee55 --- /dev/null +++ b/guppy/checker/stmt_checker.py @@ -0,0 +1,140 @@ +"""Type checking code for statements. + +Operates on statements in a basic block after CFG construction. In particular, we +assume that statements involving control flow (i.e. if, while, break, and return +statements) have been removed during CFG construction. + +After checking, we return a desugared statement where all sub-expression have been type +annotated. +""" + +import ast +from typing import Sequence + +from guppy.ast_util import with_loc, AstVisitor +from guppy.cfg.bb import BB, BBStatement +from guppy.checker.core import Variable, Context +from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType +from guppy.nodes import NestedFunctionDef + + +class StmtChecker(AstVisitor[BBStatement]): + ctx: Context + bb: BB + return_ty: GuppyType + + def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: + self.ctx = ctx + self.bb = bb + self.return_ty = return_ty + + def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: + return [self.visit(s) for s in stmts] + + def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + return ExprSynthesizer(self.ctx).synthesize(node) + + def _check_expr( + self, node: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + return ExprChecker(self.ctx).check(node, ty, kind) + + def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: + """Helper function to check assignments with patterns.""" + match lhs: + # Easiest case is if the LHS pattern is a single variable. + case ast.Name(id=x): + # Check if we override an unused linear variable + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + self.ctx.locals[x] = Variable(x, ty, node, None) + + # The only other thing we support right now are tuples + case ast.Tuple(elts=elts): + tys = ty.element_types if isinstance(ty, TupleType) else [ty] + n, m = len(elts), len(tys) + if n != m: + raise GuppyTypeError( + f"{'Too many' if n < m else 'Not enough'} values to unpack " + f"(expected {n}, got {m})", + node, + ) + for pat, el_ty in zip(elts, tys): + self._check_assign(pat, el_ty, node) + + # TODO: Python also supports assignments like `[a, b] = [1, 2]` or + # `a, *b = ...`. The former would require some runtime checks but + # the latter should be easier to do (unpack and repack the rest). + case _: + raise GuppyError("Assignment pattern not supported", lhs) + + def visit_Assign(self, node: ast.Assign) -> ast.stmt: + if len(node.targets) > 1: + # This is the case for assignments like `a = b = 1` + raise GuppyError("Multi assignment not supported", node) + + [target] = node.targets + node.value, ty = self._synth_expr(node.value) + self._check_assign(target, ty, node) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: + if node.value is None: + raise GuppyError( + "Variable declaration is not supported. Assignment is required", node + ) + ty = type_from_ast(node.annotation, self.ctx.globals) + node.value = self._check_expr(node.value, ty) + self._check_assign(node.target, ty, node) + return node + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: + bin_op = with_loc( + node, ast.BinOp(left=node.target, op=node.op, right=node.value) + ) + assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) + return self.visit_Assign(assign) + + def visit_Expr(self, node: ast.Expr) -> ast.stmt: + # An expression statement where the return value is discarded + node.value, ty = self._synth_expr(node.value) + if ty.linear: + raise GuppyTypeError(f"Value with linear type `{ty}` is not used", node) + return node + + def visit_Return(self, node: ast.Return) -> ast.stmt: + if node.value is not None: + node.value = self._check_expr(node.value, self.return_ty, "return value") + elif not isinstance(self.return_ty, NoneType): + raise GuppyTypeError( + f"Expected return value of type `{self.return_ty}`", None + ) + return node + + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: + from guppy.checker.func_checker import check_nested_func_def + + func_def = check_nested_func_def(node, self.bb, self.ctx) + self.ctx.locals[func_def.name] = Variable( + func_def.name, func_def.ty, func_def, None + ) + return func_def + + def visit_If(self, node: ast.If) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_While(self, node: ast.While) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Break(self, node: ast.Break) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Continue(self, node: ast.Continue) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") diff --git a/guppy/compiler.py b/guppy/compiler.py deleted file mode 100644 index da4f3f20..00000000 --- a/guppy/compiler.py +++ /dev/null @@ -1,207 +0,0 @@ -import ast -import functools -import inspect -import sys -import textwrap -from dataclasses import dataclass -from types import ModuleType -from typing import Optional, Any, Callable, Union - -from guppy.ast_util import is_empty_body -from guppy.compiler_base import Globals -from guppy.extension import GuppyExtension -from guppy.function import FunctionDefCompiler, DefinedFunction -from guppy.hugr.hugr import Hugr -from guppy.error import GuppyError, SourceLoc - - -def format_source_location( - source_lines: list[str], - loc: Union[ast.AST, ast.operator, ast.expr, ast.arg, ast.Name], - line_offset: int, - num_lines: int = 3, - indent: int = 4, -) -> str: - """Creates a pretty banner to show source locations for errors.""" - assert loc.end_col_offset is not None # TODO - s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() - s += "\n" + loc.col_offset * " " + (loc.end_col_offset - loc.col_offset) * "^" - s = textwrap.dedent(s).splitlines() - # Add line numbers - line_numbers = [ - str(line_offset + loc.lineno - i) + ":" for i in range(num_lines, 0, -1) - ] - longest = max(len(ln) for ln in line_numbers) - prefixes = [ln + " " * (longest - len(ln) + indent) for ln in line_numbers] - res = "".join(prefix + line + "\n" for prefix, line in zip(prefixes, s[:-1])) - res += (longest + indent) * " " + s[-1] - return res - - -@dataclass -class RawFunction: - pyfun: Callable[..., Any] - ast: ast.FunctionDef - source_lines: list[str] - line_offset: int - - -class GuppyModule(object): - """A Guppy module backed by a Hugr graph. - - Instances of this class can be used as a decorator to add functions to the module. - After all functions are added, `compile()` must be called to obtain the Hugr. - """ - - name: str - globals: Globals - - _func_defs: dict[str, RawFunction] - _func_decls: dict[str, RawFunction] - - def __init__(self, name: str): - self.name = name - self.globals = Globals.default() - self._func_defs = {} - self._func_decls = {} - - # Load all prelude extensions - import guppy.prelude.builtin - import guppy.prelude.boolean - import guppy.prelude.float - import guppy.prelude.integer - - self.load(guppy.prelude.builtin) - self.load(guppy.prelude.boolean) - self.load(guppy.prelude.float) - self.load(guppy.prelude.integer) - - def register_func(self, f: Callable[..., Any]) -> None: - """Registers a Python function as belonging to this Guppy module. - - This can be used for both function definitions and declarations. To mark a - declaration, the body of the function may only contain an ellipsis expression. - """ - func = self._parse(f) - if is_empty_body(func.ast): - self._func_decls[func.ast.name] = func - else: - self._func_defs[func.ast.name] = func - - def load(self, m: Union[ModuleType, GuppyExtension]) -> None: - """Loads a Guppy extension from a python module. - - This function must be called for names from the extension to become available in - the Guppy. - """ - if isinstance(m, GuppyExtension): - self.globals |= m.globals - else: - for ext in m.__dict__.values(): - if isinstance(ext, GuppyExtension): - self.globals |= ext.globals - - def _parse(self, f: Callable[..., Any]) -> RawFunction: - source_lines, line_offset = inspect.getsourcelines(f) - line_offset -= 1 - source = "".join(source_lines) # Lines already have trailing \n's - source = textwrap.dedent(source) - func_ast = ast.parse(source).body[0] - if not isinstance(func_ast, ast.FunctionDef): - raise GuppyError("Only functions can be placed in modules", func_ast) - if func_ast.name in self._func_defs: - raise GuppyError( - f"Module `{self.name}` already contains a function named `{func_ast.name}` " - f"(declared at {SourceLoc.from_ast(self._func_defs[func_ast.name].ast, line_offset)})", - func_ast, - ) - return RawFunction(f, func_ast, source_lines, line_offset) - - def compile(self, exit_on_error: bool = False) -> Optional[Hugr]: - """Compiles the module and returns the final Hugr.""" - graph = Hugr(self.name) - module_node = graph.set_root_name(self.name) - try: - # Generate nodes for all function definition and declarations and add them - # to the globals - defs = {} - for name, f in self._func_defs.items(): - func_ty = FunctionDefCompiler.validate_signature(f.ast, self.globals) - def_node = graph.add_def(func_ty, module_node, f.ast.name) - defs[name] = def_node - self.globals.values[name] = DefinedFunction( - name, def_node.out_port(0), f.ast - ) - for name, f in self._func_decls.items(): - func_ty = FunctionDefCompiler.validate_signature(f.ast, self.globals) - if not is_empty_body(f.ast): - raise GuppyError( - "Function declarations may not have a body.", f.ast.body[0] - ) - decl_node = graph.add_declare(func_ty, module_node, f.ast.name) - self.globals.values[name] = DefinedFunction( - name, decl_node.out_port(0), f.ast - ) - - # Now compile functions definitions - for name, f in self._func_defs.items(): - FunctionDefCompiler(graph, self.globals).compile_global( - f.ast, defs[name] - ) - return graph - - except GuppyError as err: - if err.location: - loc = err.location - line = f.line_offset + loc.lineno - print( - "Guppy compilation failed. " - f"Error in file {inspect.getsourcefile(f.pyfun)}:{line}\n", - file=sys.stderr, - ) - print( - format_source_location(f.source_lines, loc, f.line_offset + 1), - file=sys.stderr, - ) - else: - print( - "Guppy compilation failed. " - f"Error in file {inspect.getsourcefile(f.pyfun)}\n", - file=sys.stderr, - ) - print( - f"{err.__class__.__name__}: {err.get_msg(f.line_offset)}", - file=sys.stderr, - ) - if exit_on_error: - sys.exit(1) - return None - - -def guppy( - arg: Union[Callable[..., Any], GuppyModule] -) -> Union[Optional[Hugr], Callable[[Callable[..., Any]], Callable[..., Any]]]: - """Decorator to annotate Python functions as Guppy code. - - Optionally, the `GuppyModule` in which the function should be placed can be passed - to the decorator. - """ - if isinstance(arg, GuppyModule): - - def dec(f: Callable[..., Any]) -> Callable[..., Any]: - assert isinstance(arg, GuppyModule) - arg.register_func(f) - - @functools.wraps(f) - def dummy(*args: Any, **kwargs: Any) -> Any: - raise GuppyError( - "Guppy functions can only be called in a Guppy context" - ) - - return dummy - - return dec - else: - module = GuppyModule("module") - module.register_func(arg) - return module.compile(exit_on_error=False) diff --git a/guppy/compiler/__init__.py b/guppy/compiler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py new file mode 100644 index 00000000..086236e9 --- /dev/null +++ b/guppy/compiler/cfg_compiler.py @@ -0,0 +1,173 @@ +import functools +from typing import Sequence + +from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature +from guppy.checker.core import Variable +from guppy.compiler.core import ( + CompiledGlobals, + is_return_var, + DFContainer, + return_var, + PortVariable, +) +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.compiler.stmt_compiler import StmtCompiler +from guppy.gtypes import TupleType, SumType, type_to_row +from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV + + +def compile_cfg( + cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals +) -> None: + """Compiles a CFG to Hugr.""" + insert_return_vars(cfg) + + blocks: dict[CheckedBB, CFNode] = {} + for bb in cfg.bbs: + blocks[bb] = compile_bb(bb, graph, parent, globals) + for bb in cfg.bbs: + for succ in bb.successors: + graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) + + +def compile_bb( + bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals +) -> CFNode: + """Compiles a single basic block to Hugr.""" + inputs = sort_vars(bb.sig.input_row) + + # The exit BB is completely empty + if len(bb.successors) == 0: + assert len(bb.statements) == 0 + return graph.add_exit([v.ty for v in inputs], parent) + + # Otherwise, we use a regular `Block` node + block = graph.add_block(parent) + + # Add input node and compile the statements + inp = graph.add_input(output_tys=[v.ty for v in inputs], parent=block) + dfg = DFContainer( + block, + { + v.name: PortVariable(v.name, inp.out_port(i), v.defined_at, None) + for (i, v) in enumerate(inputs) + }, + ) + dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) + + # If we branch, we also have to compile the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg) + else: + # Even if we don't branch, we still have to add a `Sum(())` predicates + unit = graph.add_make_tuple([], parent=block).out_port(0) + branch_port = graph.add_tag( + variants=[TupleType([])], tag=0, inp=unit, parent=block + ).out_port(0) + + # Finally, we have to add the block output. + outputs: Sequence[Variable] + if len(bb.successors) == 1: + # The easy case is if we don't branch: We just output all variables that are + # specified by the signature + [outputs] = bb.sig.output_rows + else: + # If we branch and the branches use the same variables, then we can use a + # regular output + first, *rest = bb.sig.output_rows + if all({v.name for v in first} == {v.name for v in r} for r in rest): + outputs = first + else: + # Otherwise, we have to output a TupleSum: We put all non-linear variables + # into the branch TupleSum and all linear variables in the normal output + # (since they are shared between all successors). This is in line with the + # ordering on variables which puts linear variables at the end. The only + # exception are return vars which must be outputted in order. + branch_port = choose_vars_for_tuple_sum( + graph=graph, + unit_sum=branch_port, + output_vars=[ + [v for v in row if not v.ty.linear or is_return_var(v.name)] + for row in bb.sig.output_rows + ], + dfg=dfg, + ) + outputs = [v for v in first if v.ty.linear and not is_return_var(v.name)] + + graph.add_output( + inputs=[branch_port] + [dfg[v.name].port for v in sort_vars(outputs)], + parent=block, + ) + return block + + +def insert_return_vars(cfg: CheckedCFG) -> None: + """Patches a CFG by annotating dummy return variables in the BB signatures. + + The statement compiler turns `return` statements into assignments of dummy variables + `%ret0`, `%ret1`, etc. We update the exit BB signature to make sure they are + correctly outputted. + """ + return_vars = [ + Variable(return_var(i), ty, None, None) + for i, ty in enumerate(type_to_row(cfg.output_ty)) + ] + # Before patching, the exit BB shouldn't take any inputs + assert len(cfg.exit_bb.sig.input_row) == 0 + cfg.exit_bb.sig = Signature(return_vars, cfg.exit_bb.sig.output_rows) + # Also patch the predecessors + for pred in cfg.exit_bb.predecessors: + # The exit BB will be the only successor + assert len(pred.sig.output_rows) == 1 and len(pred.sig.output_rows[0]) == 0 + pred.sig = Signature(pred.sig.input_row, [return_vars]) + + +def choose_vars_for_tuple_sum( + graph: Hugr, unit_sum: OutPortV, output_vars: list[VarRow], dfg: DFContainer +) -> OutPortV: + """Selects an output based on a TupleSum. + + Given `unit_sum: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, + constructs a TupleSum value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. + """ + assert isinstance(unit_sum.ty, SumType) + assert len(unit_sum.ty.element_types) == len(output_vars) + tuples = [ + graph.add_make_tuple( + inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], + parent=dfg.node, + ).out_port(0) + for vs in output_vars + ] + tys = [t.ty for t in tuples] + conditional = graph.add_conditional( + cond_input=unit_sum, inputs=tuples, parent=dfg.node + ) + for i, ty in enumerate(tys): + case = graph.add_case(conditional) + inp = graph.add_input(output_tys=tys, parent=case).out_port(i) + tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) + graph.add_output(inputs=[tag], parent=case) + return conditional.add_out_port(SumType(tys)) + + +def compare_var(x: Variable, y: Variable) -> int: + """Defines a `<` order on variables. + + We use this to determine in which order variables are outputted from basic blocks. + We need to output linear variables at the end, so we do a lexicographic ordering of + linearity and name. The only exception are return vars which must be outputted in + order. + """ + if is_return_var(x.name) and is_return_var(y.name): + return -1 if x.name < y.name else 1 + return -1 if (x.ty.linear, x.name) < (y.ty.linear, y.name) else 1 + + +def sort_vars(row: VarRow) -> list[Variable]: + """Sorts a row of variables. + + This determines the order in which they are outputted from a BB. + """ + return sorted(row, key=functools.cmp_to_key(compare_var)) diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py new file mode 100644 index 00000000..2b6a4fb2 --- /dev/null +++ b/guppy/compiler/core.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Iterator + +from guppy.ast_util import AstNode +from guppy.checker.core import Variable, CallableVariable +from guppy.gtypes import FunctionType +from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr + + +@dataclass +class PortVariable(Variable): + """Represents a local variable in a dataflow graph. + + Local variables are associated with a port in the Hugr. + """ + + port: OutPortV + + def __init__( + self, + name: str, + port: OutPortV, + defined_at: Optional[AstNode], + used: Optional[AstNode] = None, + ) -> None: + super().__init__(name, port.ty, defined_at, used) + object.__setattr__(self, "port", port) + + +class CompiledVariable(ABC, Variable): + """Abstract base class for compiled global module-level variables.""" + + @abstractmethod + def load( + self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode + ) -> OutPortV: + """Loads the variable as a value into a local dataflow graph.""" + + +class CompiledFunction(CompiledVariable, CallableVariable): + """Abstract base class a global module-level function.""" + + ty: FunctionType + + @abstractmethod + def compile_call( + self, + args: list[OutPortV], + dfg: "DFContainer", + graph: Hugr, + globals: "CompiledGlobals", + node: AstNode, + ) -> list[OutPortV]: + """Compiles a call to the function.""" + + +CompiledGlobals = dict[str, CompiledVariable] +CompiledLocals = dict[str, PortVariable] + + +@dataclass +class DFContainer: + """A dataflow graph under construction. + + This class is passed through the entire compilation pipeline and stores the node + whose dataflow child-graph is currently being constructed as well as all live local + variables. Note that the variable map is mutated in-place and always reflects the + current compilation state. + """ + + node: DFContainingNode + locals: CompiledLocals + + def __getitem__(self, item: str) -> PortVariable: + return self.locals[item] + + def __setitem__(self, key: str, value: PortVariable) -> None: + self.locals[key] = value + + def __iter__(self) -> Iterator[PortVariable]: + return iter(self.locals.values()) + + def __contains__(self, item: str) -> bool: + return item in self.locals + + def __copy__(self) -> "DFContainer": + # Make a copy of the var map so that mutating the copy doesn't + # mutate our variable mapping + return DFContainer(self.node, self.locals.copy()) + + def get_var(self, name: str) -> Optional[PortVariable]: + return self.locals.get(name, None) + + +class CompilerBase(ABC): + """Base class for the Guppy compiler.""" + + graph: Hugr + globals: CompiledGlobals + + def __init__(self, graph: Hugr, globals: CompiledGlobals) -> None: + self.graph = graph + self.globals = globals + + +def return_var(n: int) -> str: + """Name of the dummy variable for the n-th return value of a function. + + During compilation, we treat return statements like assignments of dummy variables. + For example, the statement `return e0, e1, e2` is treated like `%ret0 = e0 ; %ret1 = + e1 ; %ret2 = e2`. This way, we can reuse our existing mechanism for passing of live + variables between basic blocks.""" + return f"%ret{n}" + + +def is_return_var(x: str) -> bool: + """Checks whether the given name is a dummy return variable.""" + return x.startswith("%ret") diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py new file mode 100644 index 00000000..5fdc378b --- /dev/null +++ b/guppy/compiler/expr_compiler.py @@ -0,0 +1,116 @@ +import ast +from typing import Any, Optional + +from guppy.ast_util import AstVisitor, get_type +from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction +from guppy.error import InternalGuppyError +from guppy.gtypes import FunctionType, type_to_row, BoolType +from guppy.hugr import ops, val +from guppy.hugr.hugr import OutPortV +from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall + + +class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): + """A compiler from Guppy expressions to Hugr.""" + + dfg: DFContainer + + def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: + """Compiles an expression and returns a single port holding the output value.""" + self.dfg = dfg + with self.graph.parent(dfg.node): + res = self.visit(expr) + return res + + def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: + """Compiles a row expression and returns a list of ports, one for each value in + the row. + + On Python-level, we treat tuples like rows on top-level. However, nested tuples + are treated like regular Guppy tuples. + """ + return [self.compile(e, dfg) for e in expr_to_row(expr)] + + def visit_Constant(self, node: ast.Constant) -> OutPortV: + if value := python_value_to_hugr(node.value): + const = self.graph.add_constant(value, get_type(node)).out_port(0) + return self.graph.add_load_constant(const).out_port(0) + raise InternalGuppyError("Unsupported constant expression in compiler") + + def visit_LocalName(self, node: LocalName) -> OutPortV: + return self.dfg[node.id].port + + def visit_GlobalName(self, node: GlobalName) -> OutPortV: + return self.globals[node.id].load(self.dfg, self.graph, self.globals, node) + + def visit_Name(self, node: ast.Name) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Tuple(self, node: ast.Tuple) -> OutPortV: + return self.graph.add_make_tuple( + inputs=[self.visit(e) for e in node.elts] + ).out_port(0) + + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: + """Groups function return values into a tuple""" + if len(returns) != 1: + return self.graph.add_make_tuple(inputs=returns).out_port(0) + return returns[0] + + def visit_LocalCall(self, node: LocalCall) -> OutPortV: + func = self.visit(node.func) + assert isinstance(func.ty, FunctionType) + + args = [self.visit(arg) for arg in node.args] + call = self.graph.add_indirect_call(func, args) + rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.returns)))] + return self._pack_returns(rets) + + def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: + func = self.globals[node.func.name] + assert isinstance(func, CompiledFunction) + + args = [self.visit(arg) for arg in node.args] + rets = func.compile_call(args, self.dfg, self.graph, self.globals, node) + return self._pack_returns(rets) + + def visit_Call(self, node: ast.Call) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: + # The only case that is not desugared by the type checker is the `not` operation + # since it is not implemented via a dunder method + if isinstance(node.op, ast.Not): + arg = self.visit(node.operand) + return self.graph.add_node( + ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] + ).add_out_port(BoolType()) + + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_BinOp(self, node: ast.BinOp) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Compare(self, node: ast.Compare) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + +def expr_to_row(expr: ast.expr) -> list[ast.expr]: + """Turns an expression into a row expressions by unpacking top-level tuples.""" + return expr.elts if isinstance(expr, ast.Tuple) else [expr] + + +def python_value_to_hugr(v: Any) -> Optional[val.Value]: + """Turns a Python value into a Hugr value. + + Returns None if the Python value cannot be represented in Guppy. + """ + from guppy.prelude._internal import int_value, bool_value, float_value + + if isinstance(v, bool): + return bool_value(v) + elif isinstance(v, int): + return int_value(v) + elif isinstance(v, float): + return float_value(v) + return None diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py new file mode 100644 index 00000000..4631f0db --- /dev/null +++ b/guppy/compiler/func_compiler.py @@ -0,0 +1,113 @@ +from dataclasses import dataclass + +from guppy.ast_util import AstNode +from guppy.checker.func_checker import CheckedFunction, DefinedFunction +from guppy.compiler.cfg_compiler import compile_cfg +from guppy.compiler.core import ( + CompiledFunction, + CompiledGlobals, + DFContainer, + PortVariable, +) +from guppy.gtypes import type_to_row, FunctionType +from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode +from guppy.nodes import CheckedNestedFunctionDef + + +@dataclass +class CompiledFunctionDef(DefinedFunction, CompiledFunction): + node: DFContainingVNode + + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: + return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) + + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: + call = graph.add_call(self.node.out_port(0), args, dfg.node) + return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] + + +def compile_global_func_def( + func: CheckedFunction, + def_node: DFContainingVNode, + graph: Hugr, + globals: CompiledGlobals, +) -> CompiledFunctionDef: + """Compiles a top-level function definition to Hugr.""" + _, ports = graph.add_input_with_ports(list(func.ty.args), def_node) + cfg_node = graph.add_cfg(def_node, ports) + compile_cfg(func.cfg, graph, cfg_node, globals) + + # Add output node for the cfg + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) + + return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) + + +def compile_local_func_def( + func: CheckedNestedFunctionDef, + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, +) -> PortVariable: + """Compiles a local (nested) function definition to Hugr.""" + assert func.ty.arg_names is not None + + # Pick an order for the captured variables + captured = list(func.captured.values()) + + # Prepend captured variables to the function arguments + closure_ty = FunctionType( + [v.ty for v in captured] + list(func.ty.args), + func.ty.returns, + [v.name for v in captured] + list(func.ty.arg_names), + ) + + def_node = graph.add_def(closure_ty, dfg.node, func.name) + def_input, input_ports = graph.add_input_with_ports(list(closure_ty.args), def_node) + + # If we have captured variables and the body contains a recursive occurrence of + # the function itself, then we provide the partially applied function as a local + # variable + if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: + loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0) + partial = graph.add_partial( + loaded, [def_input.out_port(i) for i in range(len(captured))], def_node + ) + input_ports.append(partial.out_port(0)) + func.cfg.input_tys.append(func.ty) + else: + # Otherwise, we treat the function like a normal global variable + globals = globals | { + func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node) + } + + # Compile the CFG + cfg_node = graph.add_cfg(def_node, inputs=input_ports) + compile_cfg(func.cfg, graph, cfg_node, globals) + + # Add output node for the cfg + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) + + # Finally, load the function into the local data-flow graph + loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) + if len(captured) > 0: + loaded = graph.add_partial( + loaded, [dfg[v.name].port for v in captured], dfg.node + ).out_port(0) + + return PortVariable(func.name, loaded, func) diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py new file mode 100644 index 00000000..db656b6f --- /dev/null +++ b/guppy/compiler/stmt_compiler.py @@ -0,0 +1,96 @@ +import ast +from typing import Sequence + +from guppy.ast_util import AstVisitor +from guppy.checker.cfg_checker import CheckedBB +from guppy.compiler.core import ( + CompilerBase, + DFContainer, + CompiledGlobals, + PortVariable, + return_var, +) +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.error import InternalGuppyError +from guppy.gtypes import TupleType +from guppy.hugr.hugr import Hugr, OutPortV +from guppy.nodes import CheckedNestedFunctionDef + + +class StmtCompiler(CompilerBase, AstVisitor[None]): + """A compiler for Guppy statements to Hugr""" + + expr_compiler: ExprCompiler + + bb: CheckedBB + dfg: DFContainer + + def __init__(self, graph: Hugr, globals: CompiledGlobals): + super().__init__(graph, globals) + self.expr_compiler = ExprCompiler(graph, globals) + + def compile_stmts( + self, + stmts: Sequence[ast.stmt], + bb: CheckedBB, + dfg: DFContainer, + ) -> DFContainer: + """Compiles a list of basic statements into a dataflow node. + + Note that the `dfg` is mutated in-place. After compilation, the DFG will also + contain all variables that are assigned in the given list of statements. + """ + self.bb = bb + self.dfg = dfg + for s in stmts: + self.visit(s) + return self.dfg + + def _unpack_assign(self, lhs: ast.expr, port: OutPortV, node: ast.stmt) -> None: + """Updates the local DFG with assignments.""" + if isinstance(lhs, ast.Name): + x = lhs.id + self.dfg[x] = PortVariable(x, port, node) + elif isinstance(lhs, ast.Tuple): + unpack = self.graph.add_unpack_tuple(port, self.dfg.node) + for i, pat in enumerate(lhs.elts): + self._unpack_assign(pat, unpack.out_port(i), node) + else: + raise InternalGuppyError("Invalid assign pattern in compiler") + + def visit_Assign(self, node: ast.Assign) -> None: + [target] = node.targets + port = self.expr_compiler.compile(node.value, self.dfg) + self._unpack_assign(target, port, node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + assert node.value is not None + port = self.expr_compiler.compile(node.value, self.dfg) + self._unpack_assign(node.target, port, node) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Expr(self, node: ast.Expr) -> None: + self.expr_compiler.compile_row(node.value, self.dfg) + + def visit_Return(self, node: ast.Return) -> None: + # We turn returns into assignments of dummy variables, i.e. the statement + # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`. + if node.value is not None: + port = self.expr_compiler.compile(node.value, self.dfg) + if isinstance(port.ty, TupleType): + unpack = self.graph.add_unpack_tuple(port, self.dfg.node) + row = [unpack.out_port(i) for i in range(len(port.ty.element_types))] + else: + row = [port] + for i, port in enumerate(row): + name = return_var(i) + self.dfg[name] = PortVariable(name, port, node) + + def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: + from guppy.compiler.func_compiler import compile_local_func_def + + self.dfg[node.name] = compile_local_func_def( + node, self.dfg, self.graph, self.globals + ) diff --git a/guppy/compiler_base.py b/guppy/compiler_base.py deleted file mode 100644 index 02fd4651..00000000 --- a/guppy/compiler_base.py +++ /dev/null @@ -1,251 +0,0 @@ -import ast -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Iterator, Optional, Any, NamedTuple - -from guppy.ast_util import AstNode -from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType -from guppy.hugr.hugr import OutPortV, Hugr, DFContainingNode - - -ValueName = str -TypeName = str - - -@dataclass -class RawVariable: - """Class holding data associated with a variable. - - Besides the name and type, we also store an AST node where the variable was defined. - """ - - name: ValueName - ty: GuppyType - defined_at: Optional[AstNode] - - def __lt__(self, other: Any) -> bool: - # We define an ordering on variables that is used to determine in which order - # they are outputted from basic blocks. We need to output linear variables at - # the end, so we do a lexicographic ordering of linearity and name, exploiting - # the fact that `False < True` in Python. The only exception are return vars - # which must be outputted in order. - if not isinstance(other, Variable): - return NotImplemented - if is_return_var(self.name) and is_return_var(other.name): - return self.name < other.name - return (self.ty.linear, self.name) < (other.ty.linear, other.name) - - -@dataclass -class Variable(RawVariable): - """Represents a concrete variable during compilation. - - Compared to a `RawVariable`, each variable corresponds to a Hugr port. - """ - - port: OutPortV - used: Optional[AstNode] = None - - def __init__(self, name: str, port: OutPortV, defined_at: Optional[AstNode]): - super().__init__(name, port.ty, defined_at) - object.__setattr__(self, "port", port) - - -# A dictionary mapping names to live variables -VarMap = dict[ValueName, Variable] - - -@dataclass -class DFContainer: - """A dataflow graph under construction. - - This class is passed through the entire compilation pipeline and stores the node - whose dataflow child-graph is currently being constructed as well as all live - variables. Note that the variable map is mutated in-place and always reflects the - current compilation state. - """ - - node: DFContainingNode - variables: VarMap - - def __getitem__(self, item: str) -> Variable: - return self.variables[item] - - def __setitem__(self, key: str, value: Variable) -> None: - self.variables[key] = value - - def __iter__(self) -> Iterator[Variable]: - return iter(self.variables.values()) - - def __contains__(self, item: str) -> bool: - return item in self.variables - - def __copy__(self) -> "DFContainer": - # Make a copy of the var map so that mutating the copy doesn't - # mutate our variable mapping - return DFContainer(self.node, self.variables.copy()) - - def get_var(self, name: str) -> Optional[Variable]: - return self.variables.get(name, None) - - -@dataclass -class GlobalVariable(ABC, RawVariable): - """Represents a global module-level variable.""" - - @abstractmethod - def load( - self, graph: Hugr, parent: DFContainingNode, globals: "Globals", node: AstNode - ) -> OutPortV: - """Loads the global variable as a value into a local dataflow graph.""" - - -@dataclass -class GlobalFunction(GlobalVariable, ABC): - """Represents a global module-level function.""" - - ty: FunctionType - call_compiler: "CallCompiler" - - def compile_call( - self, - args: list[OutPortV], - dfg: DFContainer, - graph: Hugr, - globals: "Globals", - node: AstNode, - ) -> list[OutPortV]: - """Utility method that invokes the local `CallCompiler` to compile a function - call""" - self.call_compiler.setup(dfg, graph, globals, self, node) - return self.call_compiler.compile(args) - - def compile_call_raw( - self, - args: list[ast.expr], - dfg: DFContainer, - graph: Hugr, - globals: "Globals", - node: AstNode, - ) -> list[OutPortV]: - """Utility method that invokes the local `CallCompiler` to compile a function - call with raw argument AST nodes""" - self.call_compiler.setup(dfg, graph, globals, self, node) - return self.call_compiler.compile_raw(args) - - -class Globals(NamedTuple): - """Collection of names that are available on module-level. - - Separately stores names that are bound to values (i.e. module-level functions or - constants), to types, or to instance functions belonging to types. - """ - - values: dict[ValueName, GlobalVariable] - types: dict[TypeName, type[GuppyType]] - instance_funcs: dict[tuple[TypeName, ValueName], GlobalFunction] - - @staticmethod - def default() -> "Globals": - """Generates a `Globals` instance that is populated with all core types""" - tys: dict[str, type[GuppyType]] = { - FunctionType.name: FunctionType, - TupleType.name: TupleType, - SumType.name: SumType, - } - return Globals({}, tys, {}) - - def get_instance_func(self, ty: GuppyType, name: str) -> Optional[GlobalFunction]: - """Looks up an instance function with a given name for a type""" - return self.instance_funcs.get((ty.name, name), None) - - def __or__(self, other: "Globals") -> "Globals": - return Globals( - self.values | other.values, - self.types | other.types, - self.instance_funcs | other.instance_funcs, - ) - - def __ior__(self, other: "Globals") -> "Globals": - self.values.update(other.values) - self.types.update(other.types) - self.instance_funcs.update(other.instance_funcs) - return self - - -class CompilerBase(ABC): - """Base class for the Guppy compiler.""" - - graph: Hugr - globals: Globals - - def __init__(self, graph: Hugr, globals: Globals) -> None: - self.graph = graph - self.globals = globals - - -class CallCompiler(ABC): - """Abstract base class for function call compilers.""" - - dfg: DFContainer - graph: Hugr - globals: Globals - func: GlobalFunction - node: AstNode - - @abstractmethod - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - """Compiles a function call with the given argument ports. - - Returns a row of output ports that are returned by the function. - """ - ... - - def compile_raw(self, args: list[ast.expr]) -> list[OutPortV]: - """Compiles a function call with raw argument AST nodes. - - The default implementation invokes the `ExpressionCompiler` to first compile the - arguments and then calls out to `compile`. - """ - from guppy.expression import ExpressionCompiler - - expr_compiler = ExpressionCompiler(self.graph, self.globals) - return self.compile([expr_compiler.compile(a, self.dfg) for a in args]) - - @property - def parent(self) -> DFContainingNode: - """The parent node of the current dataflow graph""" - return self.dfg.node - - def setup( - self, - dfg: DFContainer, - graph: Hugr, - globals: "Globals", - func: GlobalFunction, - node: AstNode, - ) -> None: - """Initialises the parameters of the call compiler. - - Must be called before trying to compile. - """ - self.dfg = dfg - self.graph = graph - self.globals = globals - self.func = func - self.node = node - - -def return_var(n: int) -> str: - """Name of the dummy variable for the n-th return value of a function. - - During compilation, we treat return statements like assignments of dummy variables. - For example, the statement `return e0, e1, e2` is treated like `%ret0 = e0 ; %ret1 = - e1 ; %ret2 = e2`. This way, we can reuse our existing mechanism for passing of live - variables between basic blocks.""" - return f"%ret{n}" - - -def is_return_var(x: str) -> bool: - """Checks whether the given name is a dummy return variable.""" - return x.startswith("%ret") diff --git a/guppy/custom.py b/guppy/custom.py new file mode 100644 index 00000000..a80fd134 --- /dev/null +++ b/guppy/custom.py @@ -0,0 +1,235 @@ +import ast +from abc import ABC, abstractmethod +from typing import Optional + +from guppy.ast_util import AstNode, with_type, with_loc, get_type +from guppy.checker.core import Context, Globals +from guppy.checker.expr_checker import check_call, synthesize_call +from guppy.checker.func_checker import check_signature +from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals +from guppy.error import ( + GuppyError, + InternalGuppyError, + UnknownFunctionType, +) +from guppy.gtypes import GuppyType, FunctionType, type_to_row +from guppy.hugr import ops +from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode +from guppy.nodes import GlobalCall + + +class CustomFunction(CompiledFunction): + """A function whose type checking and compilation behaviour can be customised.""" + + defined_at: Optional[ast.FunctionDef] + + # Whether the function may be used as a higher-order value. This is only possible + # if a static type for the function is provided. + higher_order_value: bool + + call_checker: "CustomCallChecker" + call_compiler: "CustomCallCompiler" + + _ty: Optional[FunctionType] = None + _defined: dict[Node, DFContainingVNode] = {} + + def __init__( + self, + name: str, + defined_at: Optional[ast.FunctionDef], + compiler: "CustomCallCompiler", + checker: "CustomCallChecker", + higher_order_value: bool = True, + ty: Optional[FunctionType] = None, + ): + self.name = name + self.defined_at = defined_at + self.higher_order_value = higher_order_value + self.call_compiler = compiler + self.call_checker = checker + self.used = None + self._ty = ty + self._defined = {} + + @property # type: ignore + def ty(self) -> FunctionType: + if self._ty is None: + return UnknownFunctionType() + return self._ty + + @ty.setter + def ty(self, ty: FunctionType) -> None: + self._ty = ty + + def check_type(self, globals: Globals) -> None: + """Checks the type annotation on the signature declaration if provided.""" + if self._ty is not None: + return + + if self.defined_at is None: + if self.higher_order_value: + raise GuppyError( + f"Type signature for function `{self.name}` is required. " + "Alternatively, try passing `higher_order_value=False` on " + "definition." + ) + return + + try: + self._ty = check_signature(self.defined_at, globals) + except GuppyError as err: + # We can ignore the error if a custom call checker is provided and the + # function may not be used as a higher-order value + if self.call_checker is None or self.higher_order_value: + raise err + + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> ast.expr: + self.call_checker._setup(ctx, node, self) + return with_type(ty, with_loc(node, self.call_checker.check(args, ty))) + + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: "Context" + ) -> tuple[ast.expr, GuppyType]: + self.call_checker._setup(ctx, node, self) + new_node, ty = self.call_checker.synthesize(args) + return with_type(ty, with_loc(node, new_node)), ty + + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: + self.call_compiler._setup(dfg, graph, globals, node) + return self.call_compiler.compile(args) + + def load( + self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: + """Loads the custom function as a value into a local dataflow graph. + + This will place a `FunctionDef` node into the Hugr module if one for this + function doesn't already exist and loads it into the DFG. This operation will + fail if no function type has been specified. + """ + if self._ty is None: + raise GuppyError( + "This function does not support usage in a higher-order context", + node, + ) + + # Find the module node by walking up the hierarchy + module: Node = dfg.node + while not isinstance(module.op, ops.Module): + if module.parent is None: + raise InternalGuppyError( + "Encountered node that is not contained in a module." + ) + module = module.parent + + # If the function has not yet been loaded in this module, we first have to + # define it. We create a `FunctionDef` that takes some inputs, compiles a call + # to the function, and returns the results. + if module not in self._defined: + def_node = graph.add_def(self.ty, module, self.name) + _, inp_ports = graph.add_input_with_ports(list(self.ty.args), def_node) + returns = self.compile_call( + inp_ports, DFContainer(def_node, {}), graph, globals, node + ) + graph.add_output(returns, parent=def_node) + self._defined[module] = def_node + + # Finally, load the function into the local DFG + return graph.add_load_constant( + self._defined[module].out_port(0), dfg.node + ).out_port(0) + + +class CustomCallChecker(ABC): + """Protocol for custom function call type checkers.""" + + ctx: Context + node: AstNode + func: CustomFunction + + def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: + self.ctx = ctx + self.node = node + self.func = func + + @abstractmethod + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + """Checks the return value against a given type. + + Returns a (possibly) transformed and annotated AST node for the call. + """ + + @abstractmethod + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + """Synthesizes a type for the return value of a call. + + Also returns a (possibly) transformed and annotated argument list. + """ + + +class CustomCallCompiler(ABC): + """Protocol for custom function call compilers.""" + + dfg: DFContainer + graph: Hugr + globals: CompiledGlobals + node: AstNode + + def _setup( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> None: + self.dfg = dfg + self.graph = graph + self.globals = globals + self.node = node + + @abstractmethod + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + """Compiles a custom function call and returns the resulting ports.""" + + +class DefaultCallChecker(CustomCallChecker): + """Checks function calls by comparing to a type signature.""" + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + # Use default implementation from the expression checker + args = check_call(self.func.ty, args, ty, self.node, self.ctx) + return GlobalCall(func=self.func, args=args) + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + # Use default implementation from the expression checker + args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) + return GlobalCall(func=self.func, args=args), ty + + +class DefaultCallCompiler(CustomCallCompiler): + """Call compiler that invokes the regular expression compiler.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + raise NotImplementedError + + +class OpCompiler(CustomCallCompiler): + op: ops.OpType + + def __init__(self, op: ops.OpType) -> None: + self.op = op + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) + return_ty = get_type(self.node) + return [node.add_out_port(ty) for ty in type_to_row(return_ty)] + + +class NoopCompiler(CustomCallCompiler): + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + return args diff --git a/guppy/declared.py b/guppy/declared.py new file mode 100644 index 00000000..91b5283e --- /dev/null +++ b/guppy/declared.py @@ -0,0 +1,66 @@ +import ast +from dataclasses import dataclass +from typing import Optional + +from guppy.ast_util import AstNode, has_empty_body +from guppy.checker.core import Globals, Context +from guppy.checker.expr_checker import check_call, synthesize_call +from guppy.checker.func_checker import check_signature +from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals +from guppy.error import GuppyError +from guppy.gtypes import type_to_row, GuppyType +from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV +from guppy.nodes import GlobalCall + + +@dataclass +class DeclaredFunction(CompiledFunction): + """A user-declared function that compiles to a Hugr function declaration.""" + + node: Optional[VNode] = None + + @staticmethod + def from_ast( + func_def: ast.FunctionDef, name: str, globals: Globals + ) -> "DeclaredFunction": + ty = check_signature(func_def, globals) + if not has_empty_body(func_def): + raise GuppyError( + "Body of function declaration must be empty", func_def.body[0] + ) + return DeclaredFunction(name, ty, func_def, None) + + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> GlobalCall: + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) + + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[GlobalCall, GuppyType]: + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty + + def add_to_graph(self, graph: Hugr, parent: Node) -> None: + self.node = graph.add_declare(self.ty, parent, self.name) + + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: + assert self.node is not None + return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) + + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: + assert self.node is not None + call = graph.add_call(self.node.out_port(0), args, dfg.node) + return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/decorator.py b/guppy/decorator.py new file mode 100644 index 00000000..493ddcb2 --- /dev/null +++ b/guppy/decorator.py @@ -0,0 +1,195 @@ +import functools +from dataclasses import dataclass +from typing import Optional, Union, Callable, Any + +from guppy.ast_util import AstNode, has_empty_body +from guppy.custom import ( + CustomFunction, + OpCompiler, + DefaultCallChecker, + CustomCallCompiler, + CustomCallChecker, + DefaultCallCompiler, +) +from guppy.error import GuppyError, pretty_errors +from guppy.gtypes import GuppyType +from guppy.hugr import tys, ops +from guppy.hugr.hugr import Hugr +from guppy.module import GuppyModule, PyFunc, parse_py_func + +FuncDecorator = Callable[[PyFunc], PyFunc] +CustomFuncDecorator = Callable[[PyFunc], CustomFunction] +ClassDecorator = Callable[[type], type] + + +class _Guppy: + """Class for the `@guppy` decorator.""" + + # The current module + _module: Optional[GuppyModule] + + def __init__(self) -> None: + self._module = None + + def set_module(self, module: GuppyModule) -> None: + self._module = module + + @pretty_errors + def __call__( + self, arg: Union[PyFunc, GuppyModule] + ) -> Union[Optional[Hugr], FuncDecorator]: + """Decorator to annotate Python functions as Guppy code. + + Optionally, the `GuppyModule` in which the function should be placed can be passed + to the decorator. + """ + if isinstance(arg, GuppyModule): + + def dec(f: Callable[..., Any]) -> Callable[..., Any]: + assert isinstance(arg, GuppyModule) + arg.register_func_def(f) + + @functools.wraps(f) + def dummy(*args: Any, **kwargs: Any) -> Any: + raise GuppyError( + "Guppy functions can only be called in a Guppy context" + ) + + return dummy + + return dec + else: + module = self._module or GuppyModule("module") + module.register_func_def(arg) + return module.compile() + + @pretty_errors + def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator: + """Decorator to add new instance functions to a type.""" + module._instance_func_buffer = {} + + def dec(c: type) -> type: + module._register_buffered_instance_funcs(ty) + return c + + return dec + + @pretty_errors + def type( + self, + module: GuppyModule, + hugr_ty: tys.SimpleType, + name: str = "", + linear: bool = False, + ) -> ClassDecorator: + """Decorator to annotate a class definitions as Guppy types. + + Requires the static Hugr translation of the type. Additionally, the type can be + marked as linear. All `@guppy` annotated functions on the class are turned into + instance functions. + """ + module._instance_func_buffer = {} + + def dec(c: type) -> type: + _name = name or c.__name__ + + @dataclass(frozen=True) + class NewType(GuppyType): + name = _name + + @staticmethod + def build( + *args: GuppyType, node: Optional[AstNode] = None + ) -> "GuppyType": + # At the moment, custom types don't support type arguments. + if len(args) > 0: + raise GuppyError( + f"Type `{_name}` does not accept type parameters.", node + ) + return NewType() + + @property + def linear(self) -> bool: + return linear + + def to_hugr(self) -> tys.SimpleType: + return hugr_ty + + def __str__(self) -> str: + return _name + + NewType.__name__ = name + NewType.__qualname__ = _name + module.register_type(_name, NewType) + module._register_buffered_instance_funcs(NewType) + setattr(c, "_guppy_type", NewType) + return c + + return dec + + @pretty_errors + def custom( + self, + module: GuppyModule, + compiler: Optional[CustomCallCompiler] = None, + checker: Optional[CustomCallChecker] = None, + higher_order_value: bool = True, + name: str = "", + ) -> CustomFuncDecorator: + """Decorator to add custom typing or compilation behaviour to function decls. + + Optionally, usage of the function as a higher-order value can be disabled. In + that case, the function signature can be omitted if a custom call compiler is + provided. + """ + + def dec(f: PyFunc) -> CustomFunction: + func_ast = parse_py_func(f) + if not has_empty_body(func_ast): + raise GuppyError( + "Body of custom function declaration must be empty", + func_ast.body[0], + ) + call_checker = checker or DefaultCallChecker() + func = CustomFunction( + name or func_ast.name, + func_ast, + compiler or DefaultCallCompiler(), + call_checker, + higher_order_value, + ) + call_checker.func = func + module.register_custom_func(func) + return func + + return dec + + def hugr_op( + self, + module: GuppyModule, + op: ops.OpType, + checker: Optional[CustomCallChecker] = None, + higher_order_value: bool = True, + name: str = "", + ) -> CustomFuncDecorator: + """Decorator to annotate function declarations as HUGR ops.""" + return self.custom(module, OpCompiler(op), checker, higher_order_value, name) + + def declare(self, module: GuppyModule) -> FuncDecorator: + """Decorator to declare functions""" + + def dec(f: Callable[..., Any]) -> Callable[..., Any]: + module.register_func_decl(f) + + @functools.wraps(f) + def dummy(*args: Any, **kwargs: Any) -> Any: + raise GuppyError( + "Guppy functions can only be called in a Guppy context" + ) + + return dummy + + return dec + + +guppy = _Guppy() diff --git a/guppy/error.py b/guppy/error.py index 428c2238..b0ace141 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -1,11 +1,19 @@ +import ast +import functools +import sys +import textwrap from dataclasses import dataclass, field -from typing import Optional, Any, Sequence +from typing import Optional, Any, Sequence, Callable, TypeVar, cast -from guppy.ast_util import AstNode -from guppy.guppy_types import GuppyType +from guppy.ast_util import AstNode, get_line_offset, get_file, get_source +from guppy.gtypes import GuppyType, FunctionType from guppy.hugr.hugr import OutPortV, Node +# Whether the interpreter should exit when a Guppy error occurs +EXIT_ON_ERROR: bool = True + + @dataclass(frozen=True) class SourceLoc: """A source location associated with an AST node. @@ -14,13 +22,16 @@ class SourceLoc: inside the file. """ + file: str line: int col: int ast_node: Optional[AstNode] @staticmethod - def from_ast(node: AstNode, line_offset: int) -> "SourceLoc": - return SourceLoc(line_offset + node.lineno, node.col_offset, node) + def from_ast(node: AstNode) -> "SourceLoc": + file, line_offset = get_file(node), get_line_offset(node) + assert file is not None and line_offset is not None + return SourceLoc(file, line_offset + node.lineno - 1, node.col_offset, node) def __str__(self) -> str: return f"{self.line}:{self.col}" @@ -43,14 +54,14 @@ class GuppyError(Exception): # The message can also refer to AST locations using format placeholders `{0}`, `{1}` locs_in_msg: Sequence[Optional[AstNode]] = field(default_factory=list) - def get_msg(self, line_offset: int) -> str: + def get_msg(self) -> str: """Returns the message associated with this error. A line offset is needed to translate AST locations mentioned in the message into source locations in the actual file.""" return self.raw_msg.format( *( - SourceLoc.from_ast(loc, line_offset) if loc is not None else "???" + SourceLoc.from_ast(loc) if loc is not None else "???" for loc in self.locs_in_msg ) ) @@ -88,3 +99,80 @@ def node(self) -> Node: @property def offset(self) -> int: raise InternalGuppyError("Tried to access undefined Port") + + +class UnknownFunctionType(FunctionType): + """Dummy function type for custom functions without an expressible type. + + Raises an `InternalGuppyError` if one tries to access one of its members. + """ + + def __init__(self) -> None: + pass + + @property + def args(self) -> Sequence[GuppyType]: + raise InternalGuppyError("Tried to access unknown function type") + + @property + def returns(self) -> GuppyType: + raise InternalGuppyError("Tried to access unknown function type") + + @property + def args_names(self) -> Optional[Sequence[str]]: + raise InternalGuppyError("Tried to access unknown function type") + + +def format_source_location( + loc: ast.AST, + num_lines: int = 3, + indent: int = 4, +) -> str: + """Creates a pretty banner to show source locations for errors.""" + source, line_offset = get_source(loc), get_line_offset(loc) + assert source is not None and line_offset is not None + source_lines = source.splitlines(keepends=True) + end_col_offset = loc.end_col_offset or len(source_lines[loc.lineno]) + s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() + s += "\n" + loc.col_offset * " " + (end_col_offset - loc.col_offset) * "^" + s = textwrap.dedent(s).splitlines() + # Add line numbers + line_numbers = [ + str(line_offset + loc.lineno - i) + ":" for i in range(num_lines, 0, -1) + ] + longest = max(len(ln) for ln in line_numbers) + prefixes = [ln + " " * (longest - len(ln) + indent) for ln in line_numbers] + res = "".join(prefix + line + "\n" for prefix, line in zip(prefixes, s[:-1])) + res += (longest + indent) * " " + s[-1] + return res + + +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + + +def pretty_errors(f: FuncT) -> FuncT: + """Decorator to print custom error banners when a `GuppyError` occurs.""" + + @functools.wraps(f) + def wrapped(*args: Any, **kwargs: Any) -> Any: + try: + return f(*args, **kwargs) + except GuppyError as err: + # Reraise if we're missing a location + if not err.location: + raise err + loc = err.location + file, line_offset = get_file(loc), get_line_offset(loc) + assert file is not None and line_offset is not None + line = line_offset + loc.lineno - 1 + print( + f"Guppy compilation failed. Error in file {file}:{line}\n\n" + f"{format_source_location(loc)}\n" + f"{err.__class__.__name__}: {err.get_msg()}", + file=sys.stderr, + ) + if EXIT_ON_ERROR: + sys.exit(1) + return None + + return cast(FuncT, wrapped) diff --git a/guppy/expression.py b/guppy/expression.py deleted file mode 100644 index 44e3badd..00000000 --- a/guppy/expression.py +++ /dev/null @@ -1,310 +0,0 @@ -import ast -from typing import Any, Optional - -from guppy.ast_util import AstVisitor, AstNode -from guppy.compiler_base import ( - CompilerBase, - DFContainer, - GlobalFunction, - GlobalVariable, -) -from guppy.error import InternalGuppyError, GuppyTypeError, GuppyError -from guppy.guppy_types import FunctionType, GuppyType -from guppy.hugr import val, ops -from guppy.hugr.hugr import OutPortV - -# Mapping from unary AST op to dunder method and display name -unary_table: dict[type[AstNode], tuple[str, str]] = { - ast.UAdd: ("__pos__", "+"), - ast.USub: ("__neg__", "-"), - ast.Invert: ("__invert__", "~"), -} - -# Mapping from binary AST op to left dunder method, right dunder method and display name -binary_table: dict[type[AstNode], tuple[str, str, str]] = { - ast.Add: ("__add__", "__radd__", "+"), - ast.Sub: ("__sub__", "__rsub__", "-"), - ast.Mult: ("__mul__", "__rmul__", "*"), - ast.Div: ("__truediv__", "__rtruediv__", "/"), - ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), - ast.Mod: ("__mod__", "__rmod__", "%"), - ast.Pow: ("__pow__", "__rpow__", "**"), - ast.LShift: ("__lshift__", "__rlshift__", "<<"), - ast.RShift: ("__rshift__", "__rrshift__", ">>"), - ast.BitOr: ("__or__", "__ror__", "||"), - ast.BitXor: ("__xor__", "__rxor__", "^"), - ast.BitAnd: ("__and__", "__rand__", "&&"), - ast.MatMult: ("__matmul__", "__rmatmul__", "@"), - ast.Eq: ("__eq__", "__eq__", "=="), - ast.NotEq: ("__neq__", "__neq__", "!="), - ast.Lt: ("__lt__", "__gt__", "<"), - ast.LtE: ("__le__", "__ge__", "<="), - ast.Gt: ("__gt__", "__lt__", ">"), - ast.GtE: ("__ge__", "__le__", ">="), -} - - -def expr_to_row(expr: ast.expr) -> list[ast.expr]: - """Turns an expression into a row expressions by unpacking top-level tuples.""" - return expr.elts if isinstance(expr, ast.Tuple) else [expr] - - -class ExpressionCompiler(CompilerBase, AstVisitor[OutPortV]): - """A compiler from Python expressions to Hugr.""" - - dfg: DFContainer - - def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: - """Compiles an expression and returns a single port holding the output value.""" - self.dfg = dfg - with self.graph.parent(dfg.node): - res = self.visit(expr) - return res - - def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: - """Compiles a row expression and returns a list of ports, one for - each value in the row. - - On Python-level, we treat tuples like rows on top-level. However, - nested tuples are treated like regular Guppy tuples. - """ - return [self.compile(e, dfg) for e in expr_to_row(expr)] - - def _is_global_var(self, x: str) -> Optional[GlobalVariable]: - """Checks if the argument references a global variable. - - Returns the variable if it exists, otherwise `None`. - """ - if x in self.globals.values and x not in self.dfg: - return self.globals.values[x] - return None - - def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> Any: - raise GuppyError("Expression not supported", node) - - def visit_Constant(self, node: ast.Constant) -> OutPortV: - if val_ty := python_value_to_hugr(node.value): - const = self.graph.add_constant(*val_ty).out_port(0) - return self.graph.add_load_constant(const).out_port(0) - raise GuppyError("Unsupported constant expression", node) - - def visit_Name(self, node: ast.Name) -> OutPortV: - x = node.id - if x in self.dfg: - var = self.dfg[x] - if var.ty.linear and var.used is not None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - node, - [var.used], - ) - var.used = node - return self.dfg[x].port - elif x in self.globals.values: - return self.globals.values[x].load( - self.graph, self.dfg.node, self.globals, node - ) - raise InternalGuppyError( - f"Variable `{x}` is not defined in ExpressionCompiler. This should have " - f"been caught by program analysis!" - ) - - def visit_JoinedString(self, node: ast.JoinedStr) -> OutPortV: - raise GuppyError("Guppy does not support formatted strings", node) - - def visit_Tuple(self, node: ast.Tuple) -> OutPortV: - return self.graph.add_make_tuple( - inputs=[self.visit(e) for e in node.elts] - ).out_port(0) - - def visit_List(self, node: ast.List) -> OutPortV: - raise NotImplementedError() - - def visit_Set(self, node: ast.Set) -> OutPortV: - raise NotImplementedError() - - def visit_Dict(self, node: ast.Dict) -> OutPortV: - raise NotImplementedError() - - def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: - arg = self.visit(node.operand) - - # Special case for the `not` operation since it is not implemented via a dunder - # method or control-flow - if isinstance(node.op, ast.Not): - from guppy.prelude.builtin import BoolType - - func = self.globals.get_instance_func(arg.ty, "__bool__") - if func is None: - raise GuppyTypeError( - f"Expression of type `{arg.ty}` cannot be interpreted as a `bool`", - node.operand, - ) - [arg] = func.compile_call([arg], self.dfg, self.graph, self.globals, node) - return self.graph.add_node( - ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] - ).add_out_port(BoolType()) - - # Compile all other unary expressions by calling out to instance dunder methods - op, display_name = unary_table[node.op.__class__] - func = self.globals.get_instance_func(arg.ty, op) - if func is None: - raise GuppyTypeError( - f"Unary operator `{display_name}` not defined for argument of type " - f" `{arg.ty}`", - node.operand, - ) - [res] = func.compile_call([arg], self.dfg, self.graph, self.globals, node) - return res - - def _compile_binary( - self, left_expr: AstNode, right_expr: AstNode, op: AstNode, node: AstNode - ) -> OutPortV: - """Helper method to compile binary operators by calling out to dunder methods. - - For example, first try calling `__add__` on the left operand. If that fails, try - `__radd__` on the right operand. - """ - if op.__class__ not in binary_table: - raise GuppyError("This binary operation is not supported by Guppy.") - lop, rop, display_name = binary_table[op.__class__] - left, right = self.visit(left_expr), self.visit(right_expr) - - if func := self.globals.get_instance_func(left.ty, lop): - try: - [ret] = func.compile_call( - [left, right], self.dfg, self.graph, self.globals, node - ) - return ret - except GuppyError: - pass - - if func := self.globals.get_instance_func(right.ty, lop): - try: - [ret] = func.compile_call( - [left, right], self.dfg, self.graph, self.globals, node - ) - return ret - except GuppyError: - pass - - raise GuppyTypeError( - f"Binary operator `{display_name}` not defined for arguments of type " - f"`{left.ty}` and `{right.ty}`", - node, - ) - - def visit_BinOp(self, node: ast.BinOp) -> OutPortV: - return self._compile_binary(node.left, node.right, node.op, node) - - def visit_Compare(self, node: ast.Compare) -> OutPortV: - if len(node.comparators) != 1 or len(node.ops) != 1: - raise InternalGuppyError( - "BB contains chained comparison. Should have been removed during CFG " - "construction." - ) - left_expr, [op], [right_expr] = node.left, node.ops, node.comparators - return self._compile_binary(left_expr, right_expr, op, node) - - def visit_Call(self, node: ast.Call) -> OutPortV: - func = node.func - if len(node.keywords) > 0: - raise GuppyError( - "Argument passing by keyword is not supported", node.keywords[0] - ) - - # Special case for calls of global module-level functions. This also handles - # calls of extension functions. - if ( - isinstance(func, ast.Name) - and (f := self._is_global_var(func.id)) - and isinstance(f, GlobalFunction) - ): - returns = f.compile_call_raw( - node.args, self.dfg, self.graph, self.globals, node - ) - - # Otherwise, compile the function like any other expression - else: - port = self.visit(func) - args = [self.visit(arg) for arg in node.args] - if isinstance(port.ty, FunctionType): - type_check_call(port.ty, args, node) - call = self.graph.add_indirect_call(port, args) - returns = [call.out_port(i) for i in range(len(port.ty.returns))] - elif f := self.globals.get_instance_func(port.ty, "__call__"): - returns = f.compile_call(args, self.dfg, self.graph, self.globals, node) - else: - raise GuppyTypeError(f"Expected function type, got `{port.ty}`", func) - - # Group outputs into tuple - if len(returns) != 1: - return self.graph.add_make_tuple(inputs=returns).out_port(0) - return returns[0] - - def visit_NamedExpr(self, node: ast.NamedExpr) -> OutPortV: - raise InternalGuppyError( - "BB contains `NamedExpr`. Should have been removed during CFG" - f"construction: `{ast.unparse(node)}`" - ) - - def visit_BoolOp(self, node: ast.BoolOp) -> OutPortV: - raise InternalGuppyError( - "BB contains `BoolOp`. Should have been removed during CFG construction: " - f"`{ast.unparse(node)}`" - ) - - def visit_IfExp(self, node: ast.IfExp) -> OutPortV: - raise InternalGuppyError( - "BB contains `IfExp`. Should have been removed during CFG construction: " - f"`{ast.unparse(node)}`" - ) - - -def check_num_args(exp: int, act: int, node: AstNode) -> None: - """Checks that the correct number of arguments have been passed to a function.""" - if act < exp: - raise GuppyTypeError( - f"Not enough arguments passed (expected {exp}, got {act})", node - ) - if exp < act: - if isinstance(node, ast.Call): - raise GuppyTypeError("Unexpected argument", node.args[exp]) - raise GuppyTypeError( - f"Too many arguments passed (expected {exp}, got {act})", node - ) - - -def type_check_call(func_ty: FunctionType, args: list[OutPortV], node: AstNode) -> None: - """Type-checks the arguments for a function call.""" - check_num_args(len(func_ty.args), len(args), node) - for i, port in enumerate(args): - if port.ty != func_ty.args[i]: - raise GuppyTypeError( - f"Expected argument of type `{func_ty.args[i]}`, got `{port.ty}`", - node.args[i] if isinstance(node, ast.Call) else node, - ) - - -def python_value_to_hugr(v: Any) -> Optional[tuple[val.Value, GuppyType]]: - """Turns a Python value into a Hugr value together with its type. - - Returns None if the Python value cannot be represented in Guppy. - """ - from guppy.prelude.builtin import ( - IntType, - BoolType, - FloatType, - int_value, - bool_value, - float_value, - ) - - if isinstance(v, bool): - return bool_value(v), BoolType() - elif isinstance(v, int): - return int_value(v), IntType() - elif isinstance(v, float): - return float_value(v), FloatType() - return None diff --git a/guppy/extension.py b/guppy/extension.py deleted file mode 100644 index a19792a7..00000000 --- a/guppy/extension.py +++ /dev/null @@ -1,386 +0,0 @@ -import ast -import builtins -import inspect -import textwrap -from dataclasses import dataclass, field -from types import ModuleType -from typing import Optional, Callable, Any, Union, Sequence - -from guppy.ast_util import AstNode, is_empty_body -from guppy.compiler_base import ( - GlobalFunction, - Globals, - CallCompiler, - DFContainer, -) -from guppy.error import GuppyError, InternalGuppyError -from guppy.expression import type_check_call -from guppy.function import FunctionDefCompiler -from guppy.guppy_types import FunctionType, GuppyType -from guppy.hugr import ops, tys as tys -from guppy.hugr.hugr import Hugr, DFContainingNode, OutPortV, Node, DFContainingVNode - - -class ExtensionDefinitionError(Exception): - """Exception indicating a failure while defining an extension.""" - - def __init__(self, msg: str, extension: "GuppyExtension") -> None: - super().__init__( - f"Definition of extension `{extension.name}` is invalid: {msg}" - ) - - -@dataclass -class ExtensionFunction(GlobalFunction): - """A custom function to extend Guppy with functionality. - - Must be provided with a `CallCompiler` that handles compilation of calls to the - extension function. This allows for full flexibility in how extensions are compiled. - Additionally, it can be specified whether the function can be used as a value in a - higher-order context. - """ - - call_compiler: CallCompiler - higher_order_value: bool = True - - _defined: dict[Node, DFContainingVNode] = field(default_factory=dict, init=False) - - def load( - self, graph: Hugr, parent: DFContainingNode, globals: Globals, node: AstNode - ) -> OutPortV: - """Loads the extension function as a value into a local dataflow graph. - - This will place a `FunctionDef` node into the Hugr module if one for this - function doesn't already exist and loads it into the DFG. This operation will - fail if the extension function is marked as not supporting higher-order usage. - """ - if not self.higher_order_value: - raise GuppyError( - "This function does not support usage in a higher-order context", - node, - ) - - # Find the module node by walking up the hierarchy - module: Node = parent - while not isinstance(module.op, ops.Module): - if module.parent is None: - raise InternalGuppyError( - "Encountered node that is not contained in a module." - ) - module = module.parent - - # If the function has not yet been loaded in this module, we first have to - # define it. We create a `FunctionDef` that takes some inputs, compiles a call - # to the function, and returns the results. - if module not in self._defined: - def_node = graph.add_def(self.ty, module, self.name) - inp = graph.add_input(list(self.ty.args), parent=def_node) - returns = self.compile_call( - [inp.out_port(i) for i in range(len(self.ty.args))], - DFContainer(def_node, {}), - graph, - globals, - node, - ) - graph.add_output(returns, parent=def_node) - self._defined[module] = def_node - - # Finally, load the function into the local DFG - return graph.add_load_constant( - self._defined[module].out_port(0), parent - ).out_port(0) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - raise GuppyError("Tried to call Guppy function in a Python context") - - -class UntypedExtensionFunction(ExtensionFunction): - """An extension function that does not require a signature. - - As a result, functions like this cannot be used in a higher-order context. - """ - - def __init__( - self, name: str, defined_at: Optional[AstNode], call_compiler: CallCompiler - ) -> None: - self.name = name - self.defined_at = defined_at - self.higher_order = False - self.call_compiler = call_compiler - - @property # type: ignore - def ty(self) -> FunctionType: - raise InternalGuppyError( - "Tried to access signature from untyped extension function" - ) - - -class GuppyExtension: - """A Guppy extension. - - Consists of a collection of types, extension functions, and instance functions. - Note that extensions can also declare instance functions for types that are not - defined in this extension. - """ - - name: str - - # Globals for all new names defined by this extension - globals: Globals - - # Globals for this extension including core types and all dependencies - _all_globals: Globals - - def __init__(self, name: str, dependencies: Sequence[ModuleType]) -> None: - """Creates a new empty Guppy extension. - - If the extension uses types from other extensions (for example the `builtin.py` - extensions from the prelude), they have to be passed as dependencies. - """ - - self.name = name - self.globals = Globals({}, {}, {}) - self._all_globals = Globals.default() - - for module in dependencies: - exts = [ - obj - for obj in module.__dict__.values() - if isinstance(obj, GuppyExtension) - ] - if len(exts) == 0: - raise ExtensionDefinitionError( - f"Dependency module `{module.__name__}` does not contain a Guppy extension", - self, - ) - for ext in exts: - self._all_globals |= ext.globals - - def register_type(self, name: str, ty: type[GuppyType]) -> None: - """Registers an existing `GuppyType` subclass with this extension.""" - self.globals.types[name] = ty - self._all_globals.types[name] = ty - - def register_func(self, name: str, func: "ExtensionFunction") -> None: - """Registers an existing `ExtensionFunction` with this extension.""" - self.globals.values[name] = func - - def register_instance_func( - self, ty: type[GuppyType], name: str, func: "ExtensionFunction" - ) -> None: - """Registers an existing function as an instance function for the type with the - given name.""" - self.globals.instance_funcs[ty.name, name] = func - - def new_type( - self, name: str, hugr_repr: tys.SimpleType, linear: bool = False - ) -> type[GuppyType]: - """Creates a new type. - - Requires the static Hugr translation of the type. Additionally, the type can be - marked as linear. - """ - _name = name - - class NewType(GuppyType): - name = _name - - @staticmethod - def build( - *args: GuppyType, node: Union[ast.Name, ast.Subscript] - ) -> "GuppyType": - # At the moment, extension types don't support type arguments. - if len(args) > 0: - raise GuppyError( - f"Type `{name}` does not accept type parameters.", node - ) - return NewType() - - @property - def linear(self) -> bool: - return linear - - def to_hugr(self) -> tys.SimpleType: - return hugr_repr - - def __eq__(self, other: Any) -> bool: - return isinstance(other, NewType) - - def __str__(self) -> str: - return name - - NewType.__name__ = NewType.__qualname__ = name - self.register_type(name, NewType) - return NewType - - def type( - self, - hugr_repr: tys.SimpleType, - alias: Optional[str] = None, - linear: bool = False, - ) -> Callable[[type], type]: - """Class decorator to annotate a new Guppy type. - - Requires the static Hugr translation of the type. Additionally, the type can be - marked as linear and an alias can be provided to be used in place of the class - name. - """ - - def decorator(cls: type) -> type: - return self.new_type(alias or cls.__name__, hugr_repr, linear) - - return decorator - - def new_func( - self, - name: str, - call_compiler: CallCompiler, - signature: Optional[FunctionType] = None, - higher_order_value: bool = True, - instance: Optional[builtins.type[GuppyType]] = None, - ) -> ExtensionFunction: - """Creates a new extension function. - - Passing a `GuppyType` with `instance=...` marks the function as an instance - function for the given type. A type signature may be omitted if higher-order - usage of the function is disabled. - """ - func: ExtensionFunction - if signature is None: - if higher_order_value: - raise ExtensionDefinitionError( - "Signature may only be omitted if `higher_order=False` is set", - self, - ) - func = UntypedExtensionFunction(name, None, call_compiler) # TODO: Location - else: - func = ExtensionFunction( - name, - signature, - None, - call_compiler, - higher_order_value, # TODO: Location - ) - if instance is not None: - self.register_instance_func(instance, name, func) - else: - self.register_func(name, func) - - return func - - def func( - self, - call_compiler: CallCompiler, - alias: Optional[str] = None, - higher_order_value: bool = True, - instance: Optional[builtins.type[GuppyType]] = None, - ) -> Callable[[Callable[..., Any]], ExtensionFunction]: - """Decorator to annotate a new extension function. - - Passing a `GuppyType` with `instance=...` marks the function as an instance - function for the given type. The type signature is extracted from the Python - type annotations on the function. They may only be omitted if higher-order - usage of the function is disabled. - """ - - def decorator(f: Callable[..., Any]) -> ExtensionFunction: - func_ast, ty = self._parse_decl(f) - name = alias or func_ast.name - return self.new_func(name, call_compiler, ty, higher_order_value, instance) - - return decorator - - def _parse_decl( - self, f: Callable[..., Any] - ) -> tuple[ast.FunctionDef, Optional[FunctionType]]: - """Helper method to parse a function into an AST. - - Also returns the function type extracted from the type annotations if they are - provided. - """ - source = textwrap.dedent(inspect.getsource(f)) - func_ast = ast.parse(source).body[0] - if not isinstance(func_ast, ast.FunctionDef): - raise ExtensionDefinitionError( - "Only functions may be annotated using `@extension`", self - ) - if not is_empty_body(func_ast): - raise ExtensionDefinitionError( - "Body of declared extension functions must be empty", self - ) - # Return None if annotations are missing - if not func_ast.returns or not all( - arg.annotation for arg in func_ast.args.args - ): - return func_ast, None - - return func_ast, FunctionDefCompiler.validate_signature( - func_ast, self._all_globals - ) - - -class OpCompiler(CallCompiler): - """Compiler for calls that can be implemented via a single Hugr op. - - Performs type checking against the signature of the function and inserts the - specified op into the graph. - """ - - op: ops.OpType - signature: Optional[FunctionType] = None - - def __init__(self, op: ops.OpType, signature: Optional[FunctionType] = None): - self.op = op - self.signature = signature - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - func_ty = self.signature or self.func.ty - type_check_call(func_ty, args, self.node) - leaf = self.graph.add_node(self.op.copy(), inputs=args, parent=self.parent) - return [leaf.add_out_port(ty) for ty in func_ty.returns] - - -class IdOpCompiler(CallCompiler): - """Compiler for calls that are no-ops. - - Compiles a call by directly returning the arguments. Type checking can be disabled - by passing `type_check=False`. - """ - - type_check: bool - - def __init__(self, type_check: bool = True): - self.type_check = type_check - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - if self.type_check: - func_ty = self.func.ty - type_check_call(func_ty, args, self.node) - return args - - -class Reversed(CallCompiler): - """Call compiler that reverses the arguments and calls out to a different compiler. - - Useful to implement the right-hand version of arithmetic functions, e.g. `__radd__`. - """ - - cc: CallCompiler - - def __init__(self, cc: CallCompiler): - self.cc = cc - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - self.cc.setup(self.dfg, self.graph, self.globals, self.func, self.node) - return self.cc.compile(list(reversed(args))) - - -class NotImplementedCompiler(CallCompiler): - """Call compiler that raises an error when the function is called. - - Should be used to inform users that a function that would normally be available in - Python is not yet implemented in Guppy. - """ - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - raise GuppyError("Operation is not yet implemented", self.node) diff --git a/guppy/function.py b/guppy/function.py deleted file mode 100644 index 33fae04c..00000000 --- a/guppy/function.py +++ /dev/null @@ -1,240 +0,0 @@ -import ast - -from guppy.ast_util import return_nodes_in_ast, AstNode -from guppy.cfg.bb import BB, NestedFunctionDef -from guppy.cfg.builder import CFGBuilder -from guppy.compiler_base import ( - CompilerBase, - RawVariable, - DFContainer, - Globals, - GlobalFunction, - CallCompiler, -) -from guppy.error import GuppyError -from guppy.expression import type_check_call -from guppy.guppy_types import ( - FunctionType, - type_row_from_ast, - type_from_ast, -) -from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode, DFContainingNode - - -class DefinedFunction(GlobalFunction): - ty: FunctionType - defined_at: ast.FunctionDef - port: OutPortV - - def __init__(self, name: str, port: OutPortV, defined_at: ast.FunctionDef): - assert isinstance(port.ty, FunctionType) - super().__init__(name, port.ty, defined_at, self.DefCallCompiler()) - self.port = port - - def load( - self, graph: Hugr, parent: DFContainingNode, globals: Globals, node: AstNode - ) -> OutPortV: - """Loads the function as a value into a local dataflow graph.""" - return graph.add_load_constant(self.port, parent).out_port(0) - - class DefCallCompiler(CallCompiler): - """Compiler for calls to defined functions.""" - - func: "DefinedFunction" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # Defined functions can be called using a regular direct call op - type_check_call(self.func.ty, args, self.node) - call = self.graph.add_call(self.func.port, args, self.parent) - return [call.out_port(i) for i in range(len(self.func.ty.returns))] - - -class FunctionDefCompiler(CompilerBase): - cfg_builder: CFGBuilder - - def __init__(self, graph: Hugr, globals: Globals): - super().__init__(graph, globals) - self.cfg_builder = CFGBuilder() - - @staticmethod - def validate_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: - """Checks the signature of a function definition and returns the corresponding - Guppy type.""" - if len(func_def.args.posonlyargs) != 0: - raise GuppyError( - "Positional-only parameters not supported", func_def.args.posonlyargs[0] - ) - if len(func_def.args.kwonlyargs) != 0: - raise GuppyError( - "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] - ) - if func_def.args.vararg is not None: - raise GuppyError("*args not supported", func_def.args.vararg) - if func_def.args.kwarg is not None: - raise GuppyError("**kwargs not supported", func_def.args.kwarg) - if func_def.returns is None: - # TODO: Error location is incorrect - if all(r.value is None for r in return_nodes_in_ast(func_def)): - raise GuppyError( - "Return type must be annotated. Try adding a `-> None` annotation.", - func_def, - ) - raise GuppyError("Return type must be annotated", func_def) - - arg_tys = [] - arg_names = [] - for i, arg in enumerate(func_def.args.args): - if arg.annotation is None: - raise GuppyError("Argument type must be annotated", arg) - ty = type_from_ast(arg.annotation, globals) - arg_tys.append(ty) - arg_names.append(arg.arg) - - ret_type_row = type_row_from_ast(func_def.returns, globals) - return FunctionType(arg_tys, ret_type_row.tys, arg_names) - - def compile_global( - self, - func_def: ast.FunctionDef, - def_node: DFContainingVNode, - ) -> DefinedFunction: - """Compiles a top-level function definition.""" - func_ty = self.validate_signature(func_def, self.globals) - args = func_def.args.args - - cfg = self.cfg_builder.build(func_def.body, len(func_ty.returns), self.globals) - - def_input = self.graph.add_input(parent=def_node) - cfg_node = self.graph.add_cfg( - def_node, inputs=[def_input.add_out_port(ty) for ty in func_ty.args] - ) - assert func_ty.arg_names is not None - input_sig = [ - RawVariable(x, ty, loc) - for x, ty, loc in zip(func_ty.arg_names, func_ty.args, args) - ] - cfg.compile( - self.graph, - input_sig, - list(func_ty.returns), - cfg_node, - self.globals, - ) - - # Add final output node for the def block - self.graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in func_ty.returns], - parent=def_node, - ) - - return DefinedFunction(func_def.name, def_node.out_port(0), func_def) - - def compile_local( - self, - func_def: NestedFunctionDef, - dfg: DFContainer, - bb: BB, - ) -> DefinedFunction: - """Compiles a local (nested) function definition.""" - func_ty = self.validate_signature(func_def, self.globals) - args = func_def.args.args - assert func_ty.arg_names is not None - - # We've already computed the CFG for this function while computing the CFG of - # the enclosing function - cfg = func_def.cfg - - # Find captured variables - parent_cfg = bb.cfg - def_ass_before = set(func_ty.arg_names) | dfg.variables.keys() - maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] - cfg.analyze(len(func_ty.returns), def_ass_before, maybe_ass_before) - captured = [ - dfg[x] - for x in cfg.live_before[cfg.entry_bb] - if x not in func_ty.arg_names and x in dfg - ] - - # Captured variables may not be linear - for v in captured: - if v.ty.linear: - x = v.name - using_bb = cfg.live_before[cfg.entry_bb][x] - raise GuppyError( - f"Variable `{x}` with linear type `{v.ty}` may not be used here " - f"because it was defined in an outer scope (at {{0}})", - using_bb.vars.used[x], - [v.defined_at], - ) - - # Captured variables may never be assigned to - for bb in cfg.bbs: - for v in captured: - x = v.name - if x in bb.vars.assigned: - raise GuppyError( - f"Variable `{x}` defined in an outer scope (at {{0}}) may not " - f"be assigned to", - bb.vars.assigned[x], - [v.defined_at], - ) - - # Prepend captured variables to the function arguments - closure_ty = FunctionType( - [v.ty for v in captured] + list(func_ty.args), - func_ty.returns, - [v.name for v in captured] + list(func_ty.arg_names), - ) - - def_node = self.graph.add_def(closure_ty, dfg.node, func_def.name) - def_input = self.graph.add_input(parent=def_node) - input_ports = [def_input.add_out_port(ty) for ty in closure_ty.args] - input_row = captured + [ - RawVariable(x, ty, loc) - for x, ty, loc in zip(func_ty.arg_names, func_ty.args, args) - ] - - # If we have captured variables and the body contains a recursive occurrence of - # the function itself, then we pass a version of the function with applied - # captured arguments as an extra argument. - if len(captured) > 0 and func_def.name in cfg.live_before[cfg.entry_bb]: - loaded = self.graph.add_load_constant(def_node.out_port(0), parent=def_node) - partial = self.graph.add_partial( - loaded.out_port(0), args=input_ports[: len(captured)], parent=def_node - ) - input_ports += [partial.out_port(0)] - input_row += [RawVariable(func_def.name, func_ty, func_def)] - global_values = self.globals.values - # Otherwise, we can treat the function like a normal global variable - else: - global_values = self.globals.values | { - func_def.name: DefinedFunction( - func_def.name, def_node.out_port(0), func_def - ) - } - globals = Globals( - global_values, self.globals.types, self.globals.instance_funcs - ) - - cfg_node = self.graph.add_cfg(def_node, inputs=input_ports) - cfg.compile(self.graph, input_row, list(func_ty.returns), cfg_node, globals) - - # Add final output node for the def block - self.graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in func_ty.returns], - parent=def_node, - ) - - # Finally, add partial application node to supply the captured arguments - loaded = self.graph.add_load_constant(def_node.out_port(0), parent=dfg.node) - if len(captured) > 0: - # TODO: We can probably get rid of the load here once we have a resource - # that supports partial application, instead of using a dummy Op here. - partial = self.graph.add_partial( - loaded.out_port(0), args=[v.port for v in captured], parent=dfg.node - ) - port = partial.out_port(0) - else: - port = loaded.out_port(0) - - return DefinedFunction(func_def.name, port, func_def) diff --git a/guppy/guppy_types.py b/guppy/gtypes.py similarity index 63% rename from guppy/guppy_types.py rename to guppy/gtypes.py index 808fd576..4d98e29b 100644 --- a/guppy/guppy_types.py +++ b/guppy/gtypes.py @@ -1,13 +1,13 @@ import ast from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING, Union +from typing import Optional, Sequence, TYPE_CHECKING import guppy.hugr.tys as tys from guppy.ast_util import AstNode, set_location_from if TYPE_CHECKING: - from guppy.compiler_base import Globals + from guppy.checker.core import Globals class GuppyType(ABC): @@ -20,7 +20,7 @@ class GuppyType(ABC): @staticmethod @abstractmethod - def build(*args: "GuppyType", node: Union[ast.Name, ast.Subscript]) -> "GuppyType": + def build(*args: "GuppyType", node: Optional[AstNode] = None) -> "GuppyType": pass @property @@ -33,23 +33,10 @@ def to_hugr(self) -> tys.SimpleType: pass -@dataclass(frozen=True) -class TypeRow: - tys: Sequence[GuppyType] - - def __str__(self) -> str: - if len(self.tys) == 0: - return "None" - elif len(self.tys) == 1: - return str(self.tys[0]) - else: - return f"({', '.join(str(e) for e in self.tys)})" - - @dataclass(frozen=True) class FunctionType(GuppyType): args: Sequence[GuppyType] - returns: Sequence[GuppyType] + returns: GuppyType arg_names: Optional[Sequence[str]] = field( default=None, compare=False, # Argument names are not taken into account for type equality @@ -59,17 +46,21 @@ class FunctionType(GuppyType): linear = False def __str__(self) -> str: - return f"{TypeRow(self.args)} -> {TypeRow(self.returns)}" + if len(self.args) == 1: + [arg] = self.args + return f"{arg} -> {self.returns}" + else: + return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: # Function types cannot be constructed using `build`. The type parsing code # has a special case for function types. raise NotImplementedError() def to_hugr(self) -> tys.SimpleType: ins = [t.to_hugr() for t in self.args] - outs = [t.to_hugr() for t in self.returns] + outs = [t.to_hugr() for t in type_to_row(self.returns)] func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) return tys.PolyFuncType(params=[], body=func_ty) @@ -81,7 +72,12 @@ class TupleType(GuppyType): name: str = "tuple" @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + from guppy.error import GuppyError + + # TODO: Parse empty tuples via `tuple[()]` + if len(args) == 0: + raise GuppyError("Tuple type requires generic type arguments", node) return TupleType(list(args)) def __str__(self) -> str: @@ -101,7 +97,7 @@ class SumType(GuppyType): element_types: Sequence[GuppyType] @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: # Sum types cannot be parsed and constructed using `build` since they cannot be # written by the user raise NotImplementedError() @@ -122,9 +118,54 @@ def to_hugr(self) -> tys.SimpleType: return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) +@dataclass(frozen=True) +class NoneType(GuppyType): + name: str = "None" + linear: bool = False + + @staticmethod + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + if len(args) > 0: + from guppy.error import GuppyError + + raise GuppyError("Type `None` is not generic", node) + return NoneType() + + def __str__(self) -> str: + return "None" + + def to_hugr(self) -> tys.SimpleType: + return tys.Tuple(inner=[]) + + +@dataclass(frozen=True) +class BoolType(SumType): + """The type of booleans.""" + + linear = False + name = "bool" + + def __init__(self) -> None: + # Hugr bools are encoded as Sum((), ()) + super().__init__([TupleType([]), TupleType([])]) + + @staticmethod + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + if len(args) > 0: + from guppy.error import GuppyError + + raise GuppyError("Type `bool` is not generic", node) + return BoolType() + + def __str__(self) -> str: + return "bool" + + def _lookup_type(node: AstNode, globals: "Globals") -> Optional[type[GuppyType]]: if isinstance(node, ast.Name) and node.id in globals.types: return globals.types[node.id] + if isinstance(node, ast.Constant) and node.value is None: + return NoneType if ( isinstance(node, ast.Constant) and isinstance(node.value, str) @@ -158,23 +199,42 @@ def type_from_ast(node: AstNode, globals: "Globals") -> GuppyType: if isinstance(func_args, ast.List): return FunctionType( [type_from_ast(a, globals) for a in func_args.elts], - type_row_from_ast(ret, globals).tys, + type_from_ast(ret, globals), ) from guppy.error import GuppyError raise GuppyError("Not a valid Guppy type", node) -def type_row_from_ast(node: ast.expr, globals: "Globals") -> TypeRow: +def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[GuppyType]: """Turns an AST expression into a Guppy type row. This is needed to interpret the return type annotation of functions. """ # The return type `-> None` is represented in the ast as `ast.Constant(value=None)` if isinstance(node, ast.Constant) and node.value is None: - return TypeRow([]) + return [] ty = type_from_ast(node, globals) if isinstance(ty, TupleType): - return TypeRow(ty.element_types) + return ty.element_types else: - return TypeRow([ty]) + return [ty] + + +def row_to_type(row: Sequence[GuppyType]) -> GuppyType: + """Turns a row of types into a single type by packing into a tuple.""" + if len(row) == 0: + return NoneType() + elif len(row) == 1: + return row[0] + else: + return TupleType(row) + + +def type_to_row(ty: GuppyType) -> Sequence[GuppyType]: + """Turns a type into a row of types by unpacking top-level tuples.""" + if isinstance(ty, NoneType): + return [] + if isinstance(ty, TupleType): + return ty.element_types + return [ty] diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 8baf350d..96a647f1 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -3,16 +3,18 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Optional, Iterator, Tuple, Any +from typing import Optional, Iterator, Tuple, Any, Sequence from dataclasses import field, dataclass import guppy.hugr.ops as ops import guppy.hugr.raw as raw -from guppy.guppy_types import ( +from guppy.gtypes import ( GuppyType, TupleType, FunctionType, SumType, + type_to_row, + row_to_type, ) from guppy.hugr import val @@ -373,6 +375,14 @@ def add_input( parent.input_child = node return node + def add_input_with_ports( + self, output_tys: Sequence[GuppyType], parent: Optional[Node] = None + ) -> tuple[VNode, list[OutPortV]]: + """Adds an `Input` node to the graph.""" + node = self.add_input(None, parent) + ports = [node.add_out_port(ty) for ty in output_tys] + return node, ports + def add_output( self, inputs: Optional[list[OutPortV]] = None, @@ -483,7 +493,11 @@ def add_call( """Adds a `Call` node to the graph.""" assert isinstance(def_port.ty, FunctionType) return self.add_node( - ops.Call(), None, list(def_port.ty.returns), parent, args + [def_port] + ops.Call(), + None, + list(type_to_row(def_port.ty.returns)), + parent, + args + [def_port], ) def add_indirect_call( @@ -494,7 +508,7 @@ def add_indirect_call( return self.add_node( ops.CallIndirect(), None, - list(fun_port.ty.returns), + list(type_to_row(fun_port.ty.returns)), parent, [fun_port] + args, ) @@ -628,7 +642,9 @@ def remove_dummy_nodes(self) -> "Hugr": for n in list(self.nodes()): if isinstance(n, VNode) and isinstance(n.op, ops.DummyOp): name = n.op.name - fun_ty = FunctionType(list(n.in_port_types), list(n.out_port_types)) + fun_ty = FunctionType( + list(n.in_port_types), row_to_type(n.out_port_types) + ) if name in used_names: used_names[name] += 1 name = f"{name}${used_names[name]}" diff --git a/guppy/module.py b/guppy/module.py new file mode 100644 index 00000000..e07a15ac --- /dev/null +++ b/guppy/module.py @@ -0,0 +1,222 @@ +import ast +import inspect +import textwrap +from types import ModuleType + +from typing import Callable, Any, Optional, Union + +from guppy.ast_util import annotate_location, AstNode +from guppy.checker.core import Globals, qualified_name +from guppy.checker.func_checker import DefinedFunction, check_global_func_def +from guppy.compiler.core import CompiledGlobals +from guppy.compiler.func_compiler import compile_global_func_def, CompiledFunctionDef +from guppy.custom import CustomFunction +from guppy.declared import DeclaredFunction +from guppy.error import GuppyError, pretty_errors +from guppy.gtypes import GuppyType +from guppy.hugr.hugr import Hugr + +PyFunc = Callable[..., Any] + + +class GuppyModule: + """A Guppy module that may contain function and type definitions.""" + + name: str + + # Whether the module has already been compiled + _compiled: bool + + # Globals from imported modules + _imported_globals: Globals + _imported_compiled_globals: CompiledGlobals + + # Globals for functions and types defined in this module. Only gets populated during + # compilation + _globals: Globals + _compiled_globals: CompiledGlobals + + # Mappings of functions defined in this module + _func_defs: dict[str, ast.FunctionDef] + _func_decls: dict[str, ast.FunctionDef] + _custom_funcs: dict[str, CustomFunction] + + # When `_instance_buffer` is not `None`, then all registered functions will be + # buffered in this list. They only get properly registered, once + # `_register_buffered_instance_funcs` is called. This way, we can associate + _instance_func_buffer: Optional[dict[str, Union[PyFunc, CustomFunction]]] + + def __init__(self, name: str, import_builtins: bool = True): + self.name = name + self._globals = Globals({}, {}) + self._compiled_globals = {} + self._imported_globals = Globals.default() + self._imported_compiled_globals = {} + self._func_defs = {} + self._func_decls = {} + self._custom_funcs = {} + self._compiled = False + self._instance_func_buffer = None + + # Import builtin module + if import_builtins: + import guppy.prelude.builtins as builtins + + self.load(builtins) + + def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: + """Imports another Guppy module.""" + self._check_not_yet_compiled() + if isinstance(m, GuppyModule): + # Compile module if it isn't compiled yet + if not m.compiled: + m.compile() + + # For now, we can only import custom functions + if any( + not isinstance(v, CustomFunction) for v in m._compiled_globals.values() + ): + raise GuppyError( + "Importing modules with defined functions is not supported yet" + ) + + self._imported_globals |= m._globals + self._imported_compiled_globals |= m._compiled_globals + else: + for val in m.__dict__.values(): + if isinstance(val, GuppyModule): + self.load(val) + + def register_func_def( + self, f: PyFunc, instance: Optional[type[GuppyType]] = None + ) -> None: + """Registers a Python function definition as belonging to this Guppy module.""" + self._check_not_yet_compiled() + func_ast = parse_py_func(f) + if self._instance_func_buffer is not None: + self._instance_func_buffer[func_ast.name] = f + else: + name = ( + qualified_name(instance, func_ast.name) if instance else func_ast.name + ) + self._check_name_available(name, func_ast) + self._func_defs[name] = func_ast + + def register_func_decl(self, f: PyFunc) -> None: + """Registers a Python function declaration as belonging to this Guppy module.""" + self._check_not_yet_compiled() + func_ast = parse_py_func(f) + self._check_name_available(func_ast.name, func_ast) + self._func_decls[func_ast.name] = func_ast + + def register_custom_func( + self, func: CustomFunction, instance: Optional[type[GuppyType]] = None + ) -> None: + """Registers a custom function as belonging to this Guppy module.""" + self._check_not_yet_compiled() + if self._instance_func_buffer is not None: + self._instance_func_buffer[func.name] = func + else: + if instance: + func.name = qualified_name(instance, func.name) + self._check_name_available(func.name, func.defined_at) + self._custom_funcs[func.name] = func + + def register_type(self, name: str, ty: type[GuppyType]) -> None: + """Registers an existing Guppy type as belonging to this Guppy module.""" + self._globals.types[name] = ty + + def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: + assert self._instance_func_buffer is not None + buffer = self._instance_func_buffer + self._instance_func_buffer = None + for name, f in buffer.items(): + if isinstance(f, CustomFunction): + self.register_custom_func(f, instance) + else: + self.register_func_def(f, instance) + + @property + def compiled(self) -> bool: + return self._compiled + + @pretty_errors + def compile(self) -> Optional[Hugr]: + """Compiles the module and returns the final Hugr.""" + if self.compiled: + raise GuppyError("Module has already been compiled") + + # Prepare globals for type checking + for func in self._custom_funcs.values(): + func.check_type(self._imported_globals | self._globals) + defined_funcs = { + x: DefinedFunction.from_ast(f, x, self._imported_globals | self._globals) + for x, f in self._func_defs.items() + } + declared_funcs = { + x: DeclaredFunction.from_ast(f, x, self._imported_globals | self._globals) + for x, f in self._func_decls.items() + } + self._globals.values.update(self._custom_funcs) + self._globals.values.update(declared_funcs) + self._globals.values.update(defined_funcs) + + # Type check function definitions + checked = { + x: check_global_func_def(f, self._imported_globals | self._globals) + for x, f in defined_funcs.items() + } + + # Add declared functions to the graph + graph = Hugr(self.name) + module_node = graph.set_root_name(self.name) + for f in declared_funcs.values(): + f.add_to_graph(graph, module_node) + + # Prepare `FunctionDef` nodes for all function definitions + def_nodes = {x: graph.add_def(f.ty, module_node, x) for x, f in checked.items()} + self._compiled_globals |= ( + self._custom_funcs + | declared_funcs + | { + x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) + for x, f in checked.items() + } + ) + + # Compile function definitions to Hugr + for x, f in checked.items(): + compile_global_func_def( + f, + def_nodes[x], + graph, + self._imported_compiled_globals | self._compiled_globals, + ) + + self._compiled = True + return graph + + def _check_not_yet_compiled(self) -> None: + if self._compiled: + raise GuppyError(f"The module `{self.name}` has already been compiled") + + def _check_name_available(self, name: str, node: Optional[AstNode]) -> None: + if name in self._func_defs or name in self._custom_funcs: + raise GuppyError( + f"Module `{self.name}` already contains a function named `{name}`", + node, + ) + + +def parse_py_func(f: PyFunc) -> ast.FunctionDef: + source_lines, line_offset = inspect.getsourcelines(f) + source = "".join(source_lines) # Lines already have trailing \n's + source = textwrap.dedent(source) + func_ast = ast.parse(source).body[0] + file = inspect.getsourcefile(f) + if file is None: + raise GuppyError("Couldn't determine source file for function") + annotate_location(func_ast, source, file, line_offset) + if not isinstance(func_ast, ast.FunctionDef): + raise GuppyError("Expected a function definition", func_ast) + return func_ast diff --git a/guppy/nodes.py b/guppy/nodes.py new file mode 100644 index 00000000..dfd02349 --- /dev/null +++ b/guppy/nodes.py @@ -0,0 +1,78 @@ +"""Custom AST nodes used by Guppy""" + +import ast +from typing import TYPE_CHECKING, Any, Mapping + +from guppy.gtypes import FunctionType + +if TYPE_CHECKING: + from guppy.cfg.cfg import CFG + from guppy.checker.core import Variable, CallableVariable + from guppy.checker.cfg_checker import CheckedCFG + + +class LocalName(ast.expr): + id: str + + _fields = ("id",) + + +class GlobalName(ast.expr): + id: str + value: "Variable" + + _fields = ( + "id", + "value", + ) + + +class LocalCall(ast.expr): + func: ast.expr + args: list[ast.expr] + + _fields = ( + "func", + "args", + ) + + +class GlobalCall(ast.expr): + func: "CallableVariable" + args: list[ast.expr] + + # Later: Inferred type args + + _fields = ( + "func", + "args", + ) + + +class NestedFunctionDef(ast.FunctionDef): + cfg: "CFG" + ty: FunctionType + + def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.cfg = cfg + self.ty = ty + + +class CheckedNestedFunctionDef(ast.FunctionDef): + cfg: "CheckedCFG" + ty: FunctionType + captured: Mapping[str, "Variable"] + + def __init__( + self, + cfg: "CheckedCFG", + ty: FunctionType, + captured: Mapping[str, "Variable"], + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.cfg = cfg + self.ty = ty + self.captured = captured diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py new file mode 100644 index 00000000..d8a64c32 --- /dev/null +++ b/guppy/prelude/_internal.py @@ -0,0 +1,291 @@ +import ast +from typing import Optional, Literal + +from pydantic import BaseModel + +from guppy.ast_util import with_type, AstNode, with_loc, get_type +from guppy.checker.core import Context, CallableVariable +from guppy.checker.expr_checker import ExprSynthesizer, check_num_args +from guppy.custom import ( + CustomCallChecker, + DefaultCallChecker, + CustomFunction, + CustomCallCompiler, +) +from guppy.error import GuppyTypeError, GuppyError +from guppy.gtypes import GuppyType, FunctionType, BoolType +from guppy.hugr import ops, tys, val +from guppy.hugr.hugr import OutPortV +from guppy.nodes import GlobalCall + + +INT_WIDTH = 6 # 2^6 = 64 bit + + +hugr_int_type = tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.BoundedNatArg(n=INT_WIDTH)], + bound=tys.TypeBound.Eq, +) + + +hugr_float_type = tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, +) + + +class ConstIntS(BaseModel): + """Hugr representation of signed integers in the arithmetic extension.""" + + c: Literal["ConstIntS"] = "ConstIntS" + log_width: int + value: int + + +class ConstF64(BaseModel): + """Hugr representation of floats in the arithmetic extension.""" + + c: Literal["ConstF64"] = "ConstF64" + value: float + + +def bool_value(b: bool) -> val.Value: + """Returns the Hugr representation of a boolean value.""" + return val.Sum(tag=int(b), value=val.Tuple(vs=[])) + + +def int_value(i: int) -> val.Value: + """Returns the Hugr representation of an integer value.""" + return val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),)) + + +def float_value(f: float) -> val.Value: + """Returns the Hugr representation of a float value.""" + return val.ExtensionVal(c=(ConstF64(value=f),)) + + +def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops.OpType: + """Utility method to create Hugr logic ops.""" + return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) + + +def int_op( + op_name: str, ext: str = "arithmetic.int", num_params: int = 1 +) -> ops.OpType: + """Utility method to create Hugr integer arithmetic ops.""" + return ops.CustomOp( + extension=ext, + op_name=op_name, + args=num_params * [tys.BoundedNatArg(n=INT_WIDTH)], + ) + + +def float_op(op_name: str, ext: str = "arithmetic.float") -> ops.OpType: + """Utility method to create Hugr integer arithmetic ops.""" + return ops.CustomOp(extension=ext, op_name=op_name, args=[]) + + +class CoercingChecker(DefaultCallChecker): + """Function call type checker that automatically coerces arguments to float.""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + from .builtins import Int + + for i in range(len(args)): + args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) + if isinstance(ty, self.ctx.globals.types["int"]): + call = with_loc( + self.node, GlobalCall(func=Int.__float__, args=[args[i]]) + ) + args[i] = with_type(self.ctx.globals.types["float"].build(), call) + return super().synthesize(args) + + +class ReversingChecker(CustomCallChecker): + """Call checker that reverses the arguments after checking.""" + + base_checker: CustomCallChecker + + def __init__(self, base_checker: Optional[CustomCallChecker] = None): + self.base_checker = base_checker or DefaultCallChecker() + + def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: + super()._setup(ctx, node, func) + self.base_checker._setup(ctx, node, func) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + expr = self.base_checker.check(args, ty) + if isinstance(expr, GlobalCall): + expr.args = list(reversed(args)) + return expr + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + expr, ty = self.base_checker.synthesize(args) + if isinstance(expr, GlobalCall): + expr.args = list(reversed(args)) + return expr, ty + + +class UnsupportedChecker(CustomCallChecker): + """Call checker for Python builtin functions that are not available in Guppy. + + Gives the uses a nicer error message when they try to use an unsupported feature. + """ + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + raise GuppyError( + f"Builtin method `{self.func.name}` is not supported by Guppy", self.node + ) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + raise GuppyError( + f"Builtin method `{self.func.name}` is not supported by Guppy", self.node + ) + + +class DunderChecker(CustomCallChecker): + """Call checker for builtin functions that call out to dunder instance methods""" + + dunder_name: str + num_args: int + + def __init__(self, dunder_name: str, num_args: int = 1): + assert num_args > 0 + self.dunder_name = dunder_name + self.num_args = num_args + + def _get_func( + self, args: list[ast.expr] + ) -> tuple[list[ast.expr], CallableVariable]: + check_num_args(self.num_args, len(args), self.node) + fst, *rest = args + fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) + func = self.ctx.globals.get_instance_func(ty, self.dunder_name) + if func is None: + raise GuppyTypeError( + f"Builtin function `{self.func.name}` is not defined for argument of " + f"type `{ty}`", + self.node.args[0] if isinstance(self.node, ast.Call) else self.node, + ) + return [fst, *rest], func + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + args, func = self._get_func(args) + return func.synthesize_call(args, self.node, self.ctx) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + args, func = self._get_func(args) + return func.check_call(args, ty, self.node, self.ctx) + + +class CallableChecker(CustomCallChecker): + """Call checker for the builtin `callable` function""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + check_num_args(1, len(args), self.node) + [arg] = args + arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) + is_callable = ( + isinstance(ty, FunctionType) + or self.ctx.globals.get_instance_func(ty, "__call__") is not None + ) + const = with_loc(self.node, ast.Constant(value=is_callable)) + return const, BoolType() + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + args, _ = self.synthesize(args) + if not isinstance(ty, BoolType): + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `bool`", self.node + ) + return args + + +class IntTruedivCompiler(CustomCallCompiler): + """Compiler for the `int.__truediv__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Int, Float + + # Compile `truediv` using float arithmetic + [left, right] = args + [left] = Int.__float__.compile_call( + [left], self.dfg, self.graph, self.globals, self.node + ) + [right] = Int.__float__.compile_call( + [right], self.dfg, self.graph, self.globals, self.node + ) + return Float.__truediv__.compile_call( + [left, right], self.dfg, self.graph, self.globals, self.node + ) + + +class FloatBoolCompiler(CustomCallCompiler): + """Compiler for the `float.__bool__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: bool(x) = (x != 0.0) + zero_const = self.graph.add_constant( + float_value(0.0), get_type(self.node), self.dfg.node + ) + zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) + return Float.__ne__.compile_call( + [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node + ) + + +class FloatFloordivCompiler(CustomCallCompiler): + """Compiler for the `float.__floordiv__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: floordiv(x, y) = floor(truediv(x, y)) + [div] = Float.__truediv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [floor] = Float.__floor__.compile_call( + [div], self.dfg, self.graph, self.globals, self.node + ) + return [floor] + + +class FloatModCompiler(CustomCallCompiler): + """Compiler for the `float.__mod__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: mod(x, y) = x - (x // y) * y + [div] = Float.__floordiv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [mul] = Float.__mul__.compile_call( + [div, args[1]], self.dfg, self.graph, self.globals, self.node + ) + [sub] = Float.__sub__.compile_call( + [args[0], mul], self.dfg, self.graph, self.globals, self.node + ) + return [sub] + + +class FloatDivmodCompiler(CustomCallCompiler): + """Compiler for the `__divmod__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: divmod(x, y) = (div(x, y), mod(x, y)) + [div] = Float.__truediv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [mod] = Float.__mod__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] diff --git a/guppy/prelude/boolean.py b/guppy/prelude/boolean.py deleted file mode 100644 index f124d503..00000000 --- a/guppy/prelude/boolean.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Guppy standard extension for bool operations.""" - -# mypy: disable-error-code=empty-body - -from guppy.prelude import builtin -from guppy.prelude.builtin import BoolType -from guppy.extension import ( - GuppyExtension, - OpCompiler, - IdOpCompiler, - NotImplementedCompiler, -) -from guppy.hugr import ops -from guppy.prelude.integer import IntOpCompiler - - -class BoolOpCompiler(OpCompiler): - def __init__(self, op_name: str): - super().__init__(ops.CustomOp(extension="logic", op_name=op_name, args=[])) - - -ext = GuppyExtension("boolean", [builtin]) - - -@ext.func(BoolOpCompiler("And"), instance=BoolType) -def __and__(self: bool, other: bool) -> bool: - ... - - -@ext.func(IdOpCompiler(), instance=BoolType) -def __bool__(self: bool) -> bool: - ... - - -@ext.func(IntOpCompiler("ifrombool"), instance=BoolType) -def __int__(self: bool) -> int: - ... - - -@ext.func(BoolOpCompiler("Or"), instance=BoolType) -def __or__(self: bool, other: bool) -> bool: - ... - - -@ext.func(NotImplementedCompiler(), instance=BoolType) # TODO -def __str__(self: int) -> str: - ... diff --git a/guppy/prelude/builtin.py b/guppy/prelude/builtin.py deleted file mode 100644 index 02dfc5e7..00000000 --- a/guppy/prelude/builtin.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Guppy standard extension for builtin types and methods. - -Instance methods for builtin types are defined in their own files -""" - -import ast -from typing import Union, Literal, Any - -from pydantic import BaseModel - -from guppy.compiler_base import CallCompiler -from guppy.error import GuppyError, GuppyTypeError -from guppy.expression import check_num_args -from guppy.extension import GuppyExtension, NotImplementedCompiler -from guppy.guppy_types import SumType, TupleType, GuppyType, FunctionType -from guppy.hugr import tys, val -from guppy.hugr.hugr import OutPortV -from guppy.hugr.tys import TypeBound - - -extension = GuppyExtension("builtin", []) - - -# We have to define and register the bool type by hand since we want it to be a -# subclass of `SumType` - - -class BoolType(SumType): - """The type of booleans.""" - - linear = False - name = "bool" - - def __init__(self) -> None: - # Hugr bools are encoded as Sum((), ()) - super().__init__([TupleType([]), TupleType([])]) - - @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: - if len(args) > 0: - raise GuppyError("Type `bool` is not parametric", node) - return BoolType() - - def __str__(self) -> str: - return "bool" - - def __eq__(self, other: Any) -> bool: - return isinstance(other, BoolType) - - -extension.register_type("bool", BoolType) - - -INT_WIDTH = 6 # 2^6 = 64 bit - -IntType: type[GuppyType] = extension.new_type( - name="int", - hugr_repr=tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=INT_WIDTH)], - bound=TypeBound.Eq, - ), -) - -FloatType: type[GuppyType] = extension.new_type( - name="float", - hugr_repr=tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=TypeBound.Copyable, - ), -) - -StringType: type[GuppyType] = extension.new_type( - name="str", - hugr_repr=tys.Opaque( - extension="TODO", # String hugr extension doesn't exist yet - id="string", - args=[], - bound=TypeBound.Eq, - ), -) - - -class ConstIntS(BaseModel): - """Hugr representation of signed integers in the arithmetic extension.""" - - c: Literal["ConstIntS"] = "ConstIntS" - log_width: int - value: int - - -class ConstF64(BaseModel): - """Hugr representation of floats in the arithmetic extension.""" - - c: Literal["ConstF64"] = "ConstF64" - value: float - - -def bool_value(b: bool) -> val.Value: - """Returns the Hugr representation of a boolean value.""" - return val.Sum(tag=int(b), value=val.Tuple(vs=[])) - - -def int_value(i: int) -> val.Value: - """Returns the Hugr representation of an integer value.""" - return val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),)) - - -def float_value(f: float) -> val.Value: - """Returns the Hugr representation of a float value.""" - return val.ExtensionVal(c=(ConstF64(value=f),)) - - -class CallableCompiler(CallCompiler): - """Call compiler for the builtin `callable` function""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - check_num_args(1, len(args), self.node) - [arg] = args - is_callable = ( - isinstance(arg.ty, FunctionType) - or self.globals.get_instance_func(arg.ty, "__call__") is not None - ) - const = self.graph.add_constant(bool_value(is_callable), BoolType()).out_port(0) - return [self.graph.add_load_constant(const).out_port(0)] - - -class BuiltinCompiler(CallCompiler): - """Call compiler for builtin functions that call out to dunder instance methods""" - - dunder_name: str - num_args: int - - def __init__(self, dunder_name: str, num_args: int = 1): - self.dunder_name = dunder_name - self.num_args = num_args - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - check_num_args(self.num_args, len(args), self.node) - [arg] = args - func = self.globals.get_instance_func(arg.ty, self.dunder_name) - if func is None: - raise GuppyTypeError( - f"Builtin function `{self.func.name}` is not defined for argument of " - "type `{arg.ty}`", - self.node.args[0] if isinstance(self.node, ast.Call) else self.node, - ) - return func.compile_call(args, self.dfg, self.graph, self.globals, self.node) - - -extension.new_func("abs", BuiltinCompiler("__abs__"), higher_order_value=False) -extension.new_func("bool", BuiltinCompiler("__bool__"), higher_order_value=False) -extension.new_func("callable", CallableCompiler(), higher_order_value=False) -extension.new_func("divmod", BuiltinCompiler("__divmod__"), higher_order_value=False) -extension.new_func("float", BuiltinCompiler("__float__"), higher_order_value=False) -extension.new_func("int", BuiltinCompiler("__int__"), higher_order_value=False) -extension.new_func("len", BuiltinCompiler("__len__"), higher_order_value=False) -extension.new_func( - "pow", BuiltinCompiler("__pow__", num_args=2), higher_order_value=False -) -extension.new_func("repr", BuiltinCompiler("__repr__"), higher_order_value=False) -extension.new_func("round", BuiltinCompiler("__round__"), higher_order_value=False) -extension.new_func("str", BuiltinCompiler("__str__"), higher_order_value=False) - - -# Python builtins that are not supported yet -extension.new_func("aiter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("all", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("anext", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("any", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("ascii", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("aiter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("bin", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("breakpoint", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("bytearray", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("bytes", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("chr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("classmethod", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("compile", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("complex", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("delattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("dict", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("dir", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("enumerate", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("eval", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("exec", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("filter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("format", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("frozenset", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("getattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("globals", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("hasattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("hash", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("help", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("hex", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("id", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("input", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("isinstance", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("issubclass", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("iter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("list", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("locals", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("map", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("max", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("memoryview", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("min", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("next", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("object", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("oct", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("open", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("ord", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("print", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("property", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("range", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("reversed", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("set", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("setattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("slice", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("sorted", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("staticmethod", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("sum", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("super", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("tuple", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("type", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("vars", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("zip", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("__import__", NotImplementedCompiler(), higher_order_value=False) diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py new file mode 100644 index 00000000..25dfc741 --- /dev/null +++ b/guppy/prelude/builtins.py @@ -0,0 +1,731 @@ +"""Guppy module for builtin types and operations.""" + +# mypy: disable-error-code="empty-body, misc, override, no-untyped-def" + +from guppy.custom import NoopCompiler, DefaultCallChecker +from guppy.decorator import guppy +from guppy.gtypes import BoolType +from guppy.hugr import tys, ops +from guppy.module import GuppyModule +from guppy.prelude._internal import ( + logic_op, + int_op, + hugr_int_type, + hugr_float_type, + float_op, + CoercingChecker, + ReversingChecker, + IntTruedivCompiler, + FloatBoolCompiler, + FloatDivmodCompiler, + FloatFloordivCompiler, + FloatModCompiler, + DunderChecker, + CallableChecker, + UnsupportedChecker, +) + + +builtins = GuppyModule("builtins", import_builtins=False) + + +@guppy.extend_type(builtins, BoolType) +class Bool: + @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) + def __and__(self: bool, other: bool) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __bool__(self: bool) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ifrombool")) + def __int__(self: bool) -> int: + ... + + @guppy.hugr_op(builtins, logic_op("Or", [tys.BoundedNatArg(n=2)])) + def __or__(self: bool, other: bool) -> bool: + ... + + +@guppy.type(builtins, hugr_int_type, name="int") +class Int: + @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) + def __abs__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iadd")) + def __add__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iand")) + def __and__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("itobool")) + def __bool__(self: int) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __ceil__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2)) + def __divmod__(self: int, other: int) -> tuple[int, int]: + ... + + @guppy.hugr_op(builtins, int_op("ieq")) + def __eq__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("convert_s", "arithmetic.conversions")) + def __float__(self: int) -> float: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __floor__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2)) + def __floordiv__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ige_s")) + def __ge__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("igt_s")) + def __gt__(self: int, other: int) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __int__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("inot")) + def __invert__(self: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ile_s")) + def __le__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ishl", num_params=2)) # TODO: RHS is unsigned + def __lshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ilt_s")) + def __lt__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2)) + def __mod__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imul")) + def __mul__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ine")) + def __ne__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ineg")) + def __neg__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ior")) + def __or__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __pos__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="ipow")) # TODO + def __pow__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker()) + def __radd__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker()) + def __rand__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker()) + def __rdivmod__(self: int, other: int) -> tuple[int, int]: + ... + + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker()) + def __rfloordiv__(self: int, other: int) -> int: + ... + + @guppy.hugr_op( + builtins, int_op("ishl", num_params=2), ReversingChecker() + ) # TODO: RHS is unsigned + def __rlshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker()) + def __rmod__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker()) + def __rmul__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ior"), ReversingChecker()) + def __ror__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __round__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="ipow"), ReversingChecker()) # TODO + def __rpow__(self: int, other: int) -> int: + ... + + @guppy.hugr_op( + builtins, int_op("ishr", num_params=2), ReversingChecker() + ) # TODO: RHS is unsigned + def __rrshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ishr", num_params=2)) # TODO: RHS is unsigned + def __rshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker()) + def __rsub__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker()) + def __rtruediv__(self: int, other: int) -> float: + ... + + @guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker()) + def __rxor__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("isub")) + def __sub__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, IntTruedivCompiler()) + def __truediv__(self: int, other: int) -> float: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __trunc__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ixor")) + def __xor__(self: int, other: int) -> int: + ... + + +@guppy.type(builtins, hugr_float_type, name="float") +class Float: + @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) + def __abs__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fadd"), CoercingChecker()) + def __add__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatBoolCompiler(), CoercingChecker()) + def __bool__(self: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fceil"), CoercingChecker()) + def __ceil__(self: float) -> float: + ... + + @guppy.custom(builtins, FloatDivmodCompiler(), CoercingChecker()) + def __divmod__(self: float, other: float) -> tuple[float, float]: + ... + + @guppy.hugr_op(builtins, float_op("feq"), CoercingChecker()) + def __eq__(self: float, other: float) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) + def __float__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("ffloor"), CoercingChecker()) + def __floor__(self: float) -> float: + ... + + @guppy.custom(builtins, FloatFloordivCompiler(), CoercingChecker()) + def __floordiv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fge"), CoercingChecker()) + def __ge__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fgt"), CoercingChecker()) + def __gt__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) + def __int__(self: float) -> int: + ... + + @guppy.hugr_op(builtins, float_op("fle"), CoercingChecker()) + def __le__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("flt"), CoercingChecker()) + def __lt__(self: float, other: float) -> bool: + ... + + @guppy.custom(builtins, FloatModCompiler(), CoercingChecker()) + def __mod__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fmul"), CoercingChecker()) + def __mul__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fne"), CoercingChecker()) + def __ne__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fneg"), CoercingChecker()) + def __neg__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) + def __pos__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="fpow")) # TODO + def __pow__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fadd"), ReversingChecker(CoercingChecker())) + def __radd__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatDivmodCompiler(), ReversingChecker(CoercingChecker())) + def __rdivmod__(self: float, other: float) -> tuple[float, float]: + ... + + @guppy.custom( + builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker()) + ) + def __rfloordiv__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatModCompiler(), ReversingChecker(CoercingChecker())) + def __rmod__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fmul"), ReversingChecker(CoercingChecker())) + def __rmul__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="fround")) # TODO + def __round__(self: float) -> float: + ... + + @guppy.hugr_op( + builtins, ops.DummyOp(name="fpow"), ReversingChecker(DefaultCallChecker()) + ) # TODO + def __rpow__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fsub"), ReversingChecker(CoercingChecker())) + def __rsub__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fdiv"), ReversingChecker(CoercingChecker())) + def __rtruediv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fsub"), CoercingChecker()) + def __sub__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fdiv"), CoercingChecker()) + def __truediv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) + def __trunc__(self: float) -> float: + ... + + +@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False) +def abs(x): + ... + + +@guppy.custom( + builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False +) +def _bool(x): + ... + + +@guppy.custom(builtins, checker=CallableChecker(), higher_order_value=False) +def callable(x): + ... + + +@guppy.custom( + builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False +) +def divmod(x, y): + ... + + +@guppy.custom( + builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False +) +def _float(x, y): + ... + + +@guppy.custom( + builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False +) +def _int(x): + ... + + +@guppy.custom( + builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False +) +def pow(x, y): + ... + + +@guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) +def round(x): + ... + + +# Python builtins that are not supported yet: + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def aiter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def all(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def anext(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def any(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bin(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def breakpoint(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bytearray(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bytes(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def chr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def classmethod(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def compile(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def complex(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def delattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def dict(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def dir(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def enumerate(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def eval(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def exec(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def filter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def format(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def forozenset(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def getattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def globals(): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hasattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hash(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def help(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hex(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def id(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def input(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def isinstance(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def issubclass(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def iter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def len(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def list(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def locals(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def map(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def max(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def memoryview(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def min(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def next(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def object(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def oct(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def open(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def ord(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def print(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def property(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def range(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def repr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def reversed(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def set(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def setattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def slice(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def sorted(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def staticmethod(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def str(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def sum(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def super(x): + ... + + +@guppy.custom( + builtins, name="tuple", checker=UnsupportedChecker(), higher_order_value=False +) +def _tuple(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def type(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def vars(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def zip(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def __import__(x): + ... diff --git a/guppy/prelude/float.py b/guppy/prelude/float.py deleted file mode 100644 index e59369eb..00000000 --- a/guppy/prelude/float.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Guppy standard extension for float operations.""" - -# mypy: disable-error-code=empty-body - -from guppy.prelude import builtin -from guppy.prelude.builtin import IntType, FloatType, float_value -from guppy.compiler_base import CallCompiler -from guppy.extension import ( - GuppyExtension, - OpCompiler, - Reversed, - NotImplementedCompiler, - IdOpCompiler, -) -from guppy.hugr import ops -from guppy.hugr.hugr import OutPortV - - -class FloatOpCompiler(OpCompiler): - """Compiler for calls that can be implemented via a single Hugr float op""" - - def __init__(self, op_name: str, extension: str = "arithmetic.float"): - super().__init__(ops.CustomOp(extension=extension, op_name=op_name, args=[])) - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - args = [ - self.graph.add_node( - ops.CustomOp(extension="arithmetic.conversions", op_name="convert_s"), - inputs=[arg], - parent=self.parent, - ).add_out_port(FloatType()) - if isinstance(arg.ty, IntType) - else arg - for arg in args - ] - return super().compile(args) - - -class BoolCompiler(CallCompiler): - """Compiler for the `__bool__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: bool(x) = (x != 0.0) - zero_const = self.graph.add_constant(float_value(0.0), FloatType(), self.parent) - zero = self.graph.add_load_constant(zero_const.out_port(0), self.parent) - return __ne__.compile_call( - [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node - ) - - -class FloordivCompiler(CallCompiler): - """Compiler for the `__floordiv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: floordiv(x, y) = floor(truediv(x, y)) - [div] = __truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [floor] = __floor__.compile_call( - [div], self.dfg, self.graph, self.globals, self.node - ) - return [floor] - - -class ModCompiler(CallCompiler): - """Compiler for the `__mod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: mod(x, y) = x - (x // y) * y - [div] = __floordiv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [mul] = __mul__.compile_call( - [div, args[1]], self.dfg, self.graph, self.globals, self.node - ) - [sub] = __sub__.compile_call( - [args[0], mul], self.dfg, self.graph, self.globals, self.node - ) - return [sub] - - -class DivmodCompiler(CallCompiler): - """Compiler for the `__divmod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: divmod(x, y) = (div(x, y), mod(x, y)) - [div] = __truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [mod] = __mod__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - return [self.graph.add_make_tuple([div, mod], self.parent).out_port(0)] - - -extension = GuppyExtension("float", [builtin]) - - -@extension.func(FloatOpCompiler("fabs"), instance=FloatType) -def __abs__(self: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fadd"), instance=FloatType) -def __add__(self: float, other: float) -> float: - ... - - -@extension.func(BoolCompiler(), instance=FloatType) -def __bool__(self: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("fceil"), instance=FloatType) -def __ceil__(self: float) -> float: - ... - - -@extension.func(DivmodCompiler(), instance=FloatType) -def __divmod__(self: float, other: float) -> tuple[float, float]: - ... - - -@extension.func(FloatOpCompiler("feq"), instance=FloatType) -def __eq__(self: float, other: float) -> bool: - ... - - -@extension.func(IdOpCompiler(), instance=FloatType) -def __float__(self: float) -> float: - ... - - -@extension.func(FloatOpCompiler("ffloor"), instance=FloatType) -def __floor__(self: float) -> float: - ... - - -@extension.func(FloordivCompiler(), instance=FloatType) -def __floordiv__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fge"), instance=FloatType) -def __ge__(self: float, other: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("fgt"), instance=FloatType) -def __gt__(self: float, other: float) -> bool: - ... - - -@extension.func( - FloatOpCompiler("trunc_s", "arithmetic.conversions"), instance=FloatType -) -def __int__(self: float, other: float) -> int: - ... - - -@extension.func(FloatOpCompiler("fle"), instance=FloatType) -def __le__(self: float, other: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("flt"), instance=FloatType) -def __lt__(self: float, other: float) -> bool: - ... - - -@extension.func(ModCompiler(), instance=FloatType) -def __mod__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fmul"), instance=FloatType) -def __mul__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fne"), instance=FloatType) -def __ne__(self: float, other: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("fneg"), instance=FloatType) -def __neg__(self: float, other: float) -> float: - ... - - -@extension.func(IdOpCompiler(), instance=FloatType) -def __pos__(self: float) -> float: - ... - - -@extension.func(NotImplementedCompiler(), instance=FloatType) # TODO -def __pow__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(FloatOpCompiler("fadd")), instance=FloatType) -def __radd__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(DivmodCompiler()), instance=FloatType) -def __rdivmod__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(ModCompiler()), instance=FloatType) -def __rmod__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(FloatOpCompiler("fmul")), instance=FloatType) -def __rmul__(self: float, other: float) -> float: - ... - - -@extension.func(NotImplementedCompiler(), instance=FloatType) # TODO -def __round__(self: float) -> float: - ... - - -@extension.func(Reversed(NotImplementedCompiler()), instance=FloatType) # TODO -def __rpow__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(FloatOpCompiler("fsub")), instance=FloatType) -def __rsub__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fdiv"), instance=FloatType) -def __rtruediv__(self: float, other: float) -> float: - ... - - -@extension.func(NotImplementedCompiler(), instance=FloatType) # TODO -def __str__(self: float) -> str: - ... - - -@extension.func(FloatOpCompiler("fsub"), instance=FloatType) -def __sub__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fdiv"), instance=FloatType) -def __truediv__(self: float, other: float) -> float: - ... - - -@extension.func( - FloatOpCompiler("trunc_s", "arithmetic.conversions"), instance=FloatType -) -def __trunc__(self: float, other: float) -> int: - ... diff --git a/guppy/prelude/integer.py b/guppy/prelude/integer.py deleted file mode 100644 index 182884aa..00000000 --- a/guppy/prelude/integer.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Guppy standard extension for int operations.""" - -# mypy: disable-error-code=empty-body - -from guppy.compiler_base import CallCompiler -from guppy.expression import type_check_call -from guppy.guppy_types import FunctionType -from guppy.hugr.hugr import OutPortV -from guppy.prelude import builtin -from guppy.prelude.builtin import IntType, INT_WIDTH, FloatType -from guppy.extension import ( - GuppyExtension, - OpCompiler, - Reversed, - NotImplementedCompiler, - IdOpCompiler, -) -from guppy.hugr import ops, tys - - -class IntOpCompiler(OpCompiler): - """Compiler for calls that can be implemented via a single Hugr integer op""" - - def __init__(self, op_name: str, ext: str = "arithmetic.int", num_params: int = 1): - super().__init__( - ops.CustomOp( - extension=ext, - op_name=op_name, - args=num_params * [tys.BoundedNatArg(n=INT_WIDTH)], - ) - ) - - -class TruedivCompiler(CallCompiler): - """Compiler for the `__truediv__` function""" - - signature: FunctionType = FunctionType([IntType(), IntType()], [FloatType()]) - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # Compile `truediv` using float arithmetic - import guppy.prelude.float - - type_check_call(self.signature, args, self.node) - [left, right] = args - [left] = __float__.compile_call( - [left], self.dfg, self.graph, self.globals, self.node - ) - [right] = __float__.compile_call( - [right], self.dfg, self.graph, self.globals, self.node - ) - return guppy.prelude.float.__truediv__.compile_call( - [left, right], self.dfg, self.graph, self.globals, self.node - ) - - -extension = GuppyExtension("integer", dependencies=[builtin]) - - -# TODO: Maybe wrong?? (signed vs unsigned!) -@extension.func(IntOpCompiler("iabs"), instance=IntType) -def __abs__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("iadd"), instance=IntType) -def __add__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("iand"), instance=IntType) -def __and__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("itobool"), instance=IntType) -def __bool__(self: int) -> bool: - ... - - -@extension.func(OpCompiler(ops.Noop(ty=IntType().to_hugr())), instance=IntType) -def __ceil__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("idivmod_s", num_params=2), instance=IntType) -def __divmod__(self: int, other: int) -> tuple[int, int]: - ... - - -@extension.func(IntOpCompiler("ieq"), instance=IntType) -def __eq__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("convert_s", "arithmetic.conversions"), instance=IntType) -def __float__(self: int) -> float: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __floor__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("idiv_s", num_params=2), instance=IntType) -def __floordiv__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ige_s"), instance=IntType) -def __ge__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("igt_s"), instance=IntType) -def __gt__(self: int, other: int) -> bool: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __int__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("inot"), instance=IntType) -def __invert__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("ile_s"), instance=IntType) -def __le__(self: int, other: int) -> bool: - ... - - -@extension.func( - IntOpCompiler("ishl", num_params=2), instance=IntType -) # TODO: broken (RHS is unsigned) -def __lshift__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ilt_s"), instance=IntType) -def __lt__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("imod_s", num_params=2), instance=IntType) -def __mod__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("imul"), instance=IntType) -def __mul__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ine"), instance=IntType) -def __ne__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("ineg"), instance=IntType) -def __neg__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("ior"), instance=IntType) -def __or__(self: int, other: int) -> int: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __pos__(self: int) -> int: - ... - - -@extension.func(NotImplementedCompiler(), instance=IntType) # TODO -def __pow__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("iadd")), instance=IntType) -def __radd__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("iand")), instance=IntType) -def __rand__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("idivmod_s", num_params=2)), instance=IntType) -def __rdivmod__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("idiv_s", num_params=2)), instance=IntType) -def __rfloordiv__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("ishl", num_params=2)), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rlshift__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("imod_s", num_params=2)), instance=IntType) -def __rmod__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("imul")), instance=IntType) -def __rmul__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("ior")), instance=IntType) -def __ror__(self: int, other: int) -> int: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __round__(self: int) -> int: - ... - - -@extension.func(Reversed(NotImplementedCompiler()), instance=IntType) # TODO -def __rpow__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("ishr", num_params=2)), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rrshift__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("ishr", num_params=2)), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rshift__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("isub")), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rsub__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(TruedivCompiler()), instance=IntType) -def __rtruediv__(self: int, other: int) -> float: - ... - - -@extension.func(Reversed(IntOpCompiler("ixor")), instance=IntType) -def __rxor__(self: int, other: int) -> int: - ... - - -@extension.func(NotImplementedCompiler(), instance=IntType) # TODO -def __str__(self: int) -> str: - ... - - -@extension.func(IntOpCompiler("isub"), instance=IntType) -def __sub__(self: int, other: int) -> int: - ... - - -@extension.func(TruedivCompiler(), instance=IntType) -def __truediv__(self: int, other: int) -> float: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __trunc__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ixor"), instance=IntType) -def __xor__(self: int, other: int) -> int: - ... diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index c0fcea8d..3b40af88 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -1,64 +1,65 @@ -"""Guppy standard extension for quantum operations.""" +"""Guppy standard module for quantum operations.""" # mypy: disable-error-code=empty-body +from guppy.decorator import guppy +from guppy.hugr import tys, ops from guppy.hugr.tys import TypeBound -from guppy.prelude import builtin -from guppy.extension import GuppyExtension, OpCompiler -from guppy.hugr import ops, tys +from guppy.module import GuppyModule -class QuantumOpCompiler(OpCompiler): - def __init__(self, op_name: str, ext: str = "quantum.tket2"): - super().__init__(ops.CustomOp(extension=ext, op_name=op_name, args=[])) +quantum = GuppyModule("quantum") -_hugr_qubit = tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any) +def quantum_op(op_name: str) -> ops.OpType: + """Utility method to create Hugr quantum ops.""" + return ops.CustomOp(extension="quantum.tket2", op_name=op_name, args=[]) -extension = GuppyExtension("quantum.tket2", dependencies=[builtin]) - - -@extension.type(_hugr_qubit, linear=True) +@guppy.type( + quantum, + tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), + linear=True, +) class Qubit: pass -@extension.func(QuantumOpCompiler("H")) +@guppy.hugr_op(quantum, quantum_op("H")) def h(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("CX")) +@guppy.hugr_op(quantum, quantum_op("CX")) def cx(control: Qubit, target: Qubit) -> tuple[Qubit, Qubit]: ... -@extension.func(QuantumOpCompiler("RzF64")) +@guppy.hugr_op(quantum, quantum_op("RzF64")) def rz(q: Qubit, angle: float) -> Qubit: ... -@extension.func(QuantumOpCompiler("Measure")) +@guppy.hugr_op(quantum, quantum_op("Measure")) def measure(q: Qubit) -> tuple[Qubit, bool]: ... -@extension.func(QuantumOpCompiler("T")) +@guppy.hugr_op(quantum, quantum_op("T")) def t(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("Tdg")) +@guppy.hugr_op(quantum, quantum_op("Tdg")) def tdg(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("Z")) +@guppy.hugr_op(quantum, quantum_op("Z")) def z(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("X")) +@guppy.hugr_op(quantum, quantum_op("X")) def x(q: Qubit) -> Qubit: ... diff --git a/guppy/statement.py b/guppy/statement.py deleted file mode 100644 index 502f0207..00000000 --- a/guppy/statement.py +++ /dev/null @@ -1,177 +0,0 @@ -import ast -from typing import Sequence - -from guppy.ast_util import AstVisitor, AstNode, set_location_from -from guppy.cfg.bb import BB, NestedFunctionDef -from guppy.compiler_base import ( - CompilerBase, - DFContainer, - Variable, - return_var, - Globals, -) -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.expression import ExpressionCompiler -from guppy.guppy_types import TupleType, TypeRow, GuppyType -from guppy.hugr.hugr import OutPortV, Hugr - - -class StatementCompiler(CompilerBase, AstVisitor[None]): - """A compiler for non-control-flow statements occurring in a basic block. - - Control-flow statements like loops or if-statements are not handled by this - compiler. They should be turned into a CFG made up of multiple simple basic blocks. - """ - - expr_compiler: ExpressionCompiler - - bb: BB - dfg: DFContainer - return_tys: list[GuppyType] - - def __init__(self, graph: Hugr, globals: Globals): - super().__init__(graph, globals) - self.expr_compiler = ExpressionCompiler(graph, globals) - - def compile_stmts( - self, - stmts: Sequence[ast.stmt], - bb: BB, - dfg: DFContainer, - return_tys: list[GuppyType], - ) -> DFContainer: - """Compiles a list of basic statements into a dataflow node. - - Note that the `dfg` is mutated in-place. After compilation, the DFG will also - contain all variables that are assigned in the given list of statements. - """ - self.bb = bb - self.dfg = dfg - self.return_tys = return_tys - for s in stmts: - self.visit(s) - return self.dfg - - def visit_Assign(self, node: ast.Assign) -> None: - if len(node.targets) > 1: - # This is the case for assignments like `a = b = 1` - raise GuppyError("Multi assignment not supported", node) - target = node.targets[0] - row = self.expr_compiler.compile_row(node.value, self.dfg) - if len(row) == 0: - # In Python it's fine to assign a void return with the variable being bound - # to `None` afterward. At the moment, we don't have a `None` type in Guppy, - # so we raise an error for now. - # TODO: Think about this. Maybe we should uniformly treat `None` as the - # empty tuple? - raise GuppyError("Cannot unpack empty row") - assert len(row) > 0 - - # Helper function to unpack the row based on the LHS pattern - def unpack(pattern: AstNode, ports: list[OutPortV]) -> None: - # Easiest case is if the LHS pattern is a single variable. Note we - # implicitly pack the row into a tuple if it has more than one element. - # I.e. `x = 1, 2` works just like `x = (1, 2)`. - if isinstance(pattern, ast.Name): - port = ( - ports[0] - if len(ports) == 1 - else self.graph.add_make_tuple( - inputs=ports, parent=self.dfg.node - ).out_port(0) - ) - # Check if we override an unused linear variable - x = pattern.id - if x in self.dfg: - var = self.dfg[x] - if var.ty.linear and var.used is None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` " - "is not used", - var.defined_at, - ) - self.dfg[x] = Variable(x, port, node) - # The only other thing we support right now are tuples - elif isinstance(pattern, ast.Tuple): - if len(ports) == 1 and isinstance(ports[0].ty, TupleType): - ports = list( - self.graph.add_unpack_tuple( - input_tuple=ports[0], parent=self.dfg.node - ).out_ports - ) - n, m = len(pattern.elts), len(ports) - if n != m: - raise GuppyTypeError( - f"{'Too many' if n < m else 'Not enough'} " - f"values to unpack (expected {n}, got {m})", - node, - ) - for pat, port in zip(pattern.elts, ports): - unpack(pat, [port]) - # TODO: Python also supports assignments like `[a, b] = [1, 2]` or - # `a, *b = ...`. The former would require some runtime checks but - # the latter should be easier to do (unpack and repack the rest). - else: - raise GuppyError("Assignment pattern not supported", pattern) - - unpack(target, row) - - def visit_AnnAssign(self, node: ast.AnnAssign) -> None: - # TODO: Figure out what to do with type annotations - raise NotImplementedError() - - def visit_AugAssign(self, node: ast.AugAssign) -> None: - bin_op = ast.BinOp(left=node.target, op=node.op, right=node.value) - set_location_from(bin_op, node) - assign = ast.Assign(targets=[node.target], value=bin_op) - set_location_from(assign, node) - self.visit_Assign(assign) - - def visit_Expr(self, node: ast.Expr) -> None: - self.expr_compiler.compile_row(node.value, self.dfg) - - def visit_Return(self, node: ast.Return) -> None: - if node.value is None: - row = [] - else: - port = self.expr_compiler.compile(node.value, self.dfg) - # Top-level tuples are unpacked, i.e. turned into a row - if isinstance(port.ty, TupleType): - unpack = self.graph.add_unpack_tuple(port, self.dfg.node) - row = [unpack.out_port(i) for i in range(len(port.ty.element_types))] - else: - row = [port] - - tys = [p.ty for p in row] - if tys != self.return_tys: - raise GuppyTypeError( - f"Return type mismatch: expected `{TypeRow(self.return_tys)}`, " - f"got `{TypeRow(tys)}`", - node.value, - ) - - # We turn returns into assignments of dummy variables, i.e. the statement - # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`. - for i, port in enumerate(row): - name = return_var(i) - self.dfg[name] = Variable(name, port, node) - - def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: - from guppy.function import FunctionDefCompiler - - func = FunctionDefCompiler(self.graph, self.globals).compile_local( - node, self.dfg, self.bb - ) - self.dfg[node.name] = Variable(node.name, func.port, node) - - def visit_If(self, node: ast.If) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_While(self, node: ast.While) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_Break(self, node: ast.Break) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_Continue(self, node: ast.Continue) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") diff --git a/tests/error/linear_errors/branch_use.py b/tests/error/linear_errors/branch_use.py index 6ad6b73b..9ba256c1 100644 --- a/tests/error/linear_errors/branch_use.py +++ b/tests/error/linear_errors/branch_use.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,12 +8,12 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... -@guppy(module) +@guppy.declare(module) def measure(q: Qubit) -> bool: ... @@ -26,4 +26,4 @@ def foo(b: bool) -> bool: return False -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/break_unused.py b/tests/error/linear_errors/break_unused.py index 81e18bdf..b66d2ba4 100644 --- a/tests/error/linear_errors/break_unused.py +++ b/tests/error/linear_errors/break_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,12 +8,12 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... -@guppy(module) +@guppy.declare(module) def measure() -> bool: ... @@ -30,4 +30,4 @@ def foo(i: int) -> bool: return b -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/continue_unused.py b/tests/error/linear_errors/continue_unused.py index 57e12b53..4e6194f8 100644 --- a/tests/error/linear_errors/continue_unused.py +++ b/tests/error/linear_errors/continue_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,12 +8,12 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... -@guppy(module) +@guppy.declare(module) def measure() -> bool: ... @@ -30,4 +30,4 @@ def foo(i: int) -> bool: return b -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/copy.py b/tests/error/linear_errors/copy.py index c4d13986..1f726112 100644 --- a/tests/error/linear_errors/copy.py +++ b/tests/error/linear_errors/copy.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -13,4 +13,4 @@ def foo(q: Qubit) -> tuple[Qubit, Qubit]: return q, q -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/if_both_unused.py b/tests/error/linear_errors/if_both_unused.py index 459a3f5d..4ce0fe88 100644 --- a/tests/error/linear_errors/if_both_unused.py +++ b/tests/error/linear_errors/if_both_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -22,4 +22,4 @@ def foo(b: bool) -> int: return 42 -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/if_both_unused_reassign.py b/tests/error/linear_errors/if_both_unused_reassign.py index 3858e087..dcd83e66 100644 --- a/tests/error/linear_errors/if_both_unused_reassign.py +++ b/tests/error/linear_errors/if_both_unused_reassign.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -23,4 +23,4 @@ def foo(b: bool) -> Qubit: return q -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/reassign_unused.py b/tests/error/linear_errors/reassign_unused.py index 402ff729..7a936455 100644 --- a/tests/error/linear_errors/reassign_unused.py +++ b/tests/error/linear_errors/reassign_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -19,4 +19,4 @@ def foo(q: Qubit) -> Qubit: return q -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/reassign_unused_tuple.py b/tests/error/linear_errors/reassign_unused_tuple.py index a888a67f..162e4163 100644 --- a/tests/error/linear_errors/reassign_unused_tuple.py +++ b/tests/error/linear_errors/reassign_unused_tuple.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -19,4 +19,4 @@ def foo(q: Qubit) -> tuple[Qubit, Qubit]: return q, r -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/unused.py b/tests/error/linear_errors/unused.py index f45f2141..a22b8f16 100644 --- a/tests/error/linear_errors/unused.py +++ b/tests/error/linear_errors/unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -14,4 +14,4 @@ def foo(q: Qubit) -> int: return 42 -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/unused_expr.err b/tests/error/linear_errors/unused_expr.err new file mode 100644 index 00000000..6a0ae95e --- /dev/null +++ b/tests/error/linear_errors/unused_expr.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(q: Qubit) -> None: +13: h(q) + ^^^^ +GuppyTypeError: Value with linear type `Qubit` is not used diff --git a/tests/error/linear_errors/unused_expr.py b/tests/error/linear_errors/unused_expr.py new file mode 100644 index 00000000..6ad60a82 --- /dev/null +++ b/tests/error/linear_errors/unused_expr.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(q: Qubit) -> None: + h(q) + + +module.compile() diff --git a/tests/error/linear_errors/unused_same_block.py b/tests/error/linear_errors/unused_same_block.py index ccf8ca0b..55031f6b 100644 --- a/tests/error/linear_errors/unused_same_block.py +++ b/tests/error/linear_errors/unused_same_block.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -15,4 +15,4 @@ def foo(q: Qubit) -> int: return x -module.compile(True) +module.compile() diff --git a/tests/error/nested_errors/different_types_if.err b/tests/error/nested_errors/different_types_if.err index 0426774e..da397a08 100644 --- a/tests/error/nested_errors/different_types_if.err +++ b/tests/error/nested_errors/different_types_if.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:13 12: 13: return bar() ^^^ -GuppyError: Variable `bar` can refer to different types: `None -> int` (at 7:8) vs `None -> bool` (at 10:8) +GuppyError: Variable `bar` can refer to different types: `() -> int` (at 7:8) vs `() -> bool` (at 10:8) diff --git a/tests/error/nested_errors/linear_capture.py b/tests/error/nested_errors/linear_capture.py index 1545c31b..2c89d1fc 100644 --- a/tests/error/nested_errors/linear_capture.py +++ b/tests/error/nested_errors/linear_capture.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -16,4 +16,4 @@ def bar() -> Qubit: return q -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/and_not_bool_left.py b/tests/error/type_errors/and_not_bool_left.py index 92647c8e..30f3b338 100644 --- a/tests/error/type_errors/and_not_bool_left.py +++ b/tests/error/type_errors/and_not_bool_left.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool, y: bool) -> bool: return x and y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/and_not_bool_right.py b/tests/error/type_errors/and_not_bool_right.py index 23637899..deb16df7 100644 --- a/tests/error/type_errors/and_not_bool_right.py +++ b/tests/error/type_errors/and_not_bool_right.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: bool, y: NonBool) -> bool: return x and y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/fun_ty_mismatch_1.err b/tests/error/type_errors/fun_ty_mismatch_1.err index 3e4cb8ef..a21337a1 100644 --- a/tests/error/type_errors/fun_ty_mismatch_1.err +++ b/tests/error/type_errors/fun_ty_mismatch_1.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:11 10: 11: return bar ^^^ -GuppyTypeError: Return type mismatch: expected `int -> int`, got `int -> bool` +GuppyTypeError: Expected return value of type `int -> int`, got `int -> bool` diff --git a/tests/error/type_errors/fun_ty_mismatch_2.err b/tests/error/type_errors/fun_ty_mismatch_2.err index 1f578dbd..ee3f83f4 100644 --- a/tests/error/type_errors/fun_ty_mismatch_2.err +++ b/tests/error/type_errors/fun_ty_mismatch_2.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:11 10: 11: return bar(foo) ^^^ -GuppyTypeError: Expected argument of type `None -> int`, got `int -> int` +GuppyTypeError: Expected argument of type `() -> int`, got `int -> int` diff --git a/tests/error/type_errors/if_expr_not_bool.py b/tests/error/type_errors/if_expr_not_bool.py index da546a40..41648464 100644 --- a/tests/error/type_errors/if_expr_not_bool.py +++ b/tests/error/type_errors/if_expr_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool) -> int: return 1 if x else 0 -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/if_not_bool.py b/tests/error/type_errors/if_not_bool.py index 3233d405..a51125fd 100644 --- a/tests/error/type_errors/if_not_bool.py +++ b/tests/error/type_errors/if_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -14,4 +14,4 @@ def foo(x: NonBool) -> int: return 1 -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/not_not_bool.py b/tests/error/type_errors/not_not_bool.py index f8eea0bd..03e9e083 100644 --- a/tests/error/type_errors/not_not_bool.py +++ b/tests/error/type_errors/not_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool) -> bool: return not x -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/or_not_bool_left.py b/tests/error/type_errors/or_not_bool_left.py index b7a9c2c6..9236e3b0 100644 --- a/tests/error/type_errors/or_not_bool_left.py +++ b/tests/error/type_errors/or_not_bool_left.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool, y: bool) -> bool: return x or y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/or_not_bool_right.py b/tests/error/type_errors/or_not_bool_right.py index 407f46b3..ebd7c092 100644 --- a/tests/error/type_errors/or_not_bool_right.py +++ b/tests/error/type_errors/or_not_bool_right.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: bool, y: NonBool) -> bool: return x or y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/return_mismatch.err b/tests/error/type_errors/return_mismatch.err index d7fa70cf..a59582be 100644 --- a/tests/error/type_errors/return_mismatch.err +++ b/tests/error/type_errors/return_mismatch.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:6 5: def foo() -> bool: 6: return 42 ^^ -GuppyTypeError: Return type mismatch: expected `bool`, got `int` +GuppyTypeError: Expected return value of type `bool`, got `int` diff --git a/tests/error/type_errors/while_not_bool.py b/tests/error/type_errors/while_not_bool.py index e12ea59c..2fe00a00 100644 --- a/tests/error/type_errors/while_not_bool.py +++ b/tests/error/type_errors/while_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -14,4 +14,4 @@ def foo(x: NonBool) -> int: return 0 -module.compile(True) +module.compile() diff --git a/tests/error/util.py b/tests/error/util.py index 8fc76028..a3e39191 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -1,21 +1,22 @@ import importlib.util import pathlib -from typing import Callable, Optional, Any, TypeVar - import pytest -from guppy.compiler import GuppyModule -from guppy.extension import GuppyExtension +from typing import Callable, Optional, Any + from guppy.hugr import tys -from guppy.hugr.hugr import Hugr from guppy.hugr.tys import TypeBound +from guppy.module import GuppyModule +from guppy.hugr.hugr import Hugr + +import guppy.decorator as decorator def guppy(f: Callable[..., Any]) -> Optional[Hugr]: """ Decorator to compile functions outside of modules for testing. """ module = GuppyModule("module") - module.register_func(f) - return module.compile(exit_on_error=True) + module.register_func_def(f) + return module.compile() def run_error_test(file, capsys): @@ -35,5 +36,10 @@ def run_error_test(file, capsys): assert err == exp_err -ext = GuppyExtension("test", []) -NonBool = ext.new_type("NonBool", tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)) +util = GuppyModule("test") + + +@decorator.guppy.type(util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)) +class NonBool: + pass + diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index 7125c500..e1efaab9 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -1,19 +1,15 @@ -import pytest - -from guppy.error import UndefinedPort, InternalGuppyError -from guppy.guppy_types import FunctionType +from guppy.gtypes import FunctionType, BoolType, TupleType from guppy.hugr import ops from guppy.hugr.hugr import Hugr -from guppy.prelude.builtin import BoolType, IntType def test_single_dummy(): g = Hugr() - defn = g.add_def(FunctionType([IntType()], [IntType()]), g.root, "test") + defn = g.add_def(FunctionType([BoolType()], [BoolType()]), g.root, "test") dfg = g.add_dfg(defn) - inp = g.add_input([IntType()], dfg).out_port(0) + inp = g.add_input([BoolType()], dfg).out_port(0) dummy = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[IntType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg ) g.add_output([dummy.out_port(0)], parent=dfg) @@ -24,11 +20,11 @@ def test_single_dummy(): def test_unique_names(): g = Hugr() - defn = g.add_def(FunctionType([IntType()], [IntType(), BoolType]), g.root, "test") + defn = g.add_def(FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test") dfg = g.add_dfg(defn) - inp = g.add_input([IntType()], dfg).out_port(0) + inp = g.add_input([BoolType()], dfg).out_port(0) dummy1 = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[IntType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg ) dummy2 = g.add_node( ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py index d981840e..77e07ef4 100644 --- a/tests/hugr/test_ports.py +++ b/tests/hugr/test_ports.py @@ -1,7 +1,7 @@ import pytest from guppy.error import UndefinedPort, InternalGuppyError -from guppy.prelude.builtin import BoolType +from guppy.gtypes import BoolType def test_undefined_port(): diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 7ed1c6f1..c4112fd1 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -1,4 +1,4 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy def test_arith_basic(validate): @@ -17,7 +17,7 @@ def const() -> float: validate(const) -def test_ann_assign(validate): +def test_aug_assign(validate): @guppy def add(x: int) -> int: x += 1 diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index d29fdb7f..cde2801a 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -1,5 +1,6 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy from guppy.hugr import ops +from guppy.module import GuppyModule def test_id(validate): @@ -70,9 +71,11 @@ def func_name() -> None: def test_func_decl_name(): - @guppy + module = GuppyModule("test") + + @guppy.declare(module) def func_name() -> None: ... - [def_op] = [n.op for n in func_name.nodes() if isinstance(n.op, ops.FuncDecl)] + [def_op] = [n.op for n in module.compile().nodes() if isinstance(n.op, ops.FuncDecl)] assert def_op.name == "func_name" diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index 2fe6bf75..bad4d824 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -1,4 +1,5 @@ -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule def test_call(validate): @@ -12,7 +13,7 @@ def foo() -> int: def bar() -> int: return foo() - validate(module.compile(exit_on_error=True)) + validate(module.compile()) def test_call_back(validate): @@ -26,7 +27,7 @@ def foo(x: int) -> int: def bar(x: int) -> int: return x - validate(module.compile(exit_on_error=True)) + validate(module.compile()) def test_recursion(validate): @@ -48,7 +49,7 @@ def foo(x: int) -> int: def bar(x: int) -> int: return foo(x) - validate(module.compile(exit_on_error=True)) + validate(module.compile()) diff --git a/tests/integration/test_functional.py b/tests/integration/test_functional.py index 38aa9340..8c462d6b 100644 --- a/tests/integration/test_functional.py +++ b/tests/integration/test_functional.py @@ -1,6 +1,6 @@ import pytest -from guppy.compiler import guppy +from guppy.decorator import guppy from tests.integration.util import functional, _ diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py index b6c4d359..b406c932 100644 --- a/tests/integration/test_higher_order.py +++ b/tests/integration/test_higher_order.py @@ -1,6 +1,7 @@ from typing import Callable -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule def test_basic(validate): diff --git a/tests/integration/test_if.py b/tests/integration/test_if.py index 6ea6cafc..4f09b295 100644 --- a/tests/integration/test_if.py +++ b/tests/integration/test_if.py @@ -1,6 +1,7 @@ import pytest -from guppy.compiler import guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule def test_if_no_else(validate): diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index 71ebcbb5..0d744e06 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -1,8 +1,9 @@ -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.prelude.quantum import Qubit import guppy.prelude.quantum as quantum -from guppy.prelude.quantum import h, cx +from guppy.prelude.quantum import h, cx, measure def test_id(validate): @@ -13,7 +14,7 @@ def test_id(validate): def test(q: Qubit) -> Qubit: return q - validate(module.compile(True)) + validate(module.compile()) def test_assign(validate): @@ -26,7 +27,7 @@ def test(q: Qubit) -> Qubit: s = r return s - validate(module.compile(True)) + validate(module.compile()) def test_linear_return_order(validate): @@ -38,18 +39,18 @@ def test_linear_return_order(validate): def test(q: Qubit) -> tuple[Qubit, bool]: return measure(q) - validate(module.compile(True)) + validate(module.compile()) def test_interleave(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def f(q1: Qubit, q2: Qubit) -> tuple[Qubit, Qubit]: ... - @guppy(module) + @guppy.declare(module) def g(q1: Qubit, q2: Qubit) -> tuple[Qubit, Qubit]: ... @@ -60,14 +61,14 @@ def test(a: Qubit, b: Qubit, c: Qubit, d: Qubit) -> tuple[Qubit, Qubit, Qubit, Q b, c = g(b, c) return a, b, c, d - validate(module.compile(True)) + validate(module.compile()) def test_if(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def new() -> Qubit: ... @@ -80,14 +81,14 @@ def test(b: bool) -> Qubit: q = new() return q - validate(module.compile(True)) + validate(module.compile()) def test_if_return(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def new() -> Qubit: ... @@ -100,14 +101,14 @@ def test(b: bool) -> Qubit: q = new() return q - validate(module.compile(True)) + validate(module.compile()) def test_measure(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def measure(q: Qubit) -> bool: ... @@ -116,14 +117,14 @@ def test(q: Qubit, x: int) -> int: b = measure(q) return x - validate(module.compile(True)) + validate(module.compile()) def test_return_call(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def op(q: Qubit) -> Qubit: ... @@ -131,7 +132,7 @@ def op(q: Qubit) -> Qubit: def test(q: Qubit) -> Qubit: return op(q) - validate(module.compile(True)) + validate(module.compile()) def test_while(validate): @@ -145,7 +146,7 @@ def test(q: Qubit, i: int) -> Qubit: q = h(q) return q - validate(module.compile(True)) + validate(module.compile()) def test_while_break(validate): @@ -161,7 +162,7 @@ def test(q: Qubit, i: int) -> Qubit: break return q - validate(module.compile(True)) + validate(module.compile()) def test_while_continue(validate): @@ -177,18 +178,18 @@ def test(q: Qubit, i: int) -> Qubit: q = h(q) return q - validate(module.compile(True)) + validate(module.compile()) def test_while_reset(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def new_qubit() -> Qubit: ... - @guppy(module) + @guppy.declare(module) def measure() -> bool: ... @@ -208,15 +209,15 @@ def test_rus(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def measure(q: Qubit) -> bool: ... - @guppy(module) + @guppy.declare(module) def qalloc() -> Qubit: ... - @guppy(module) + @guppy.declare(module) def t(q: Qubit) -> Qubit: ... diff --git a/tests/integration/test_nested.py b/tests/integration/test_nested.py index caac4e0d..fe58523f 100644 --- a/tests/integration/test_nested.py +++ b/tests/integration/test_nested.py @@ -1,4 +1,4 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy def test_basic(validate): diff --git a/tests/integration/test_programs.py b/tests/integration/test_programs.py index d445a10e..01c2a6fd 100644 --- a/tests/integration/test_programs.py +++ b/tests/integration/test_programs.py @@ -1,4 +1,5 @@ -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.integration.util import functional, _ @@ -53,4 +54,4 @@ def is_odd(x: int) -> bool: return False return is_even(x - 1) - validate(module.compile(exit_on_error=True)) + validate(module.compile()) diff --git a/tests/integration/test_unused.py b/tests/integration/test_unused.py index 0ad04c84..772e6b60 100644 --- a/tests/integration/test_unused.py +++ b/tests/integration/test_unused.py @@ -1,6 +1,6 @@ import pytest -from guppy.compiler import guppy +from guppy.decorator import guppy """ All sorts of weird stuff is allowed when variables are not used. """ diff --git a/tests/integration/test_while.py b/tests/integration/test_while.py index 1af3939f..28d20dc0 100644 --- a/tests/integration/test_while.py +++ b/tests/integration/test_while.py @@ -1,4 +1,4 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy def test_infinite_loop(validate):