From 414dbf165ee573ccab040d64db1b02cd3f7aa5fe Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:55:45 +0000 Subject: [PATCH 01/77] Factor out checking code --- guppy/checker/__init__.py | 0 guppy/checker/cfg_checker.py | 205 ++++++++++++++++++++++ guppy/checker/core.py | 97 ++++++++++ guppy/checker/expr_checker.py | 322 ++++++++++++++++++++++++++++++++++ guppy/checker/func_checker.py | 171 ++++++++++++++++++ guppy/checker/stmt_checker.py | 119 +++++++++++++ 6 files changed, 914 insertions(+) create mode 100644 guppy/checker/__init__.py create mode 100644 guppy/checker/cfg_checker.py create mode 100644 guppy/checker/core.py create mode 100644 guppy/checker/expr_checker.py create mode 100644 guppy/checker/func_checker.py create mode 100644 guppy/checker/stmt_checker.py diff --git a/guppy/checker/__init__.py b/guppy/checker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py new file mode 100644 index 00000000..a8dc25e5 --- /dev/null +++ b/guppy/checker/cfg_checker.py @@ -0,0 +1,205 @@ +import collections +from copy import copy +from dataclasses import dataclass +from typing import Sequence + +from guppy.ast_util import line_col +from guppy.cfg.bb import BB, BBStatement +from guppy.cfg.cfg import CFG, BaseCFG +from guppy.checker.core import Globals, Context + +from guppy.checker.core import Variable +from guppy.checker.expr_checker import ExprSynthesizer, to_bool +from guppy.checker.stmt_checker import StmtChecker +from guppy.error import GuppyError +from guppy.guppy_types import GuppyType + + +VarRow = Sequence[Variable] + + +@dataclass(frozen=True) +class Signature: + """The signature of a basic block. + + Stores the input/output variables with their types. + """ + + input_row: VarRow + output_rows: Sequence[VarRow] # One for each successor + + @staticmethod + def empty() -> "Signature": + return Signature([], []) + + +@dataclass(eq=False) # Disable equality to recover hash from `object` +class CheckedBB(BB): + """Basic block annotated with an input and output type signature.""" + + sig: Signature = Signature.empty() + + +class CheckedCFG(BaseCFG[CheckedBB]): + input_tys: list[GuppyType] + output_ty: GuppyType + + def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None: + super().__init__([]) + self.input_tys = input_tys + self.output_ty = output_ty + + +def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedCFG: + """Type checks a control-flow graph. + + Annotates the basic blocks with input and output type signatures and removes + unreachable blocks. + """ + # First, we need to run program analysis + ass_before = set(v.name for v in inputs) + cfg.analyze(ass_before, ass_before) + + # We start by compiling the entry BB + checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) + checked_cfg.entry_bb = check_bb(cfg.entry_bb, checked_cfg, inputs, return_ty, globals) + compiled = {cfg.entry_bb: checked_cfg.entry_bb} + + # Visit all control-flow edges in BFS order. We can't just do a normal loop over + # all BBs since the input types for a BB are computed by checking a predecessor. + # We do BFS instead of DFS to get a better error ordering. + queue = collections.deque( + (checked_cfg.entry_bb, i, succ) for i, succ in enumerate(cfg.entry_bb.successors) + ) + while len(queue) > 0: + pred, num_output, bb = queue.popleft() + input_row = [Variable(v.name, v.ty, v.defined_at, None) for v in pred.sig.output_rows[num_output]] + + if bb in compiled: + # If the BB was already compiled, we just have to check that the signatures + # match. + check_rows_match(input_row, compiled[bb].sig.input_row, bb) + else: + # Otherwise, check the BB and enqueue its successors + checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) + queue += [ + (checked_bb, i, succ) for i, succ in enumerate(bb.successors) + ] + compiled[bb] = checked_bb + + # Link up BBs in the checked CFG + compiled[bb].predecessors.append(pred) + pred.successors[num_output] = compiled[bb] + + checked_cfg.bbs = list(compiled.values()) + checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable + checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} + checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} + checked_cfg.maybe_ass_before = {compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs} + return checked_cfg + + +def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedBB: + cfg = bb.cfg + + # For the entry BB we have to separately check that all used variables are + # defined. For all other BBs, this will be checked when compiling a predecessor. + if len(bb.predecessors) == 0: + for x, use in bb.vars.used.items(): + if x not in cfg.ass_before[bb] and x not in globals.values: + raise GuppyError(f"Variable `{x}` is not defined", use) + + # Check the basic block + ctx = Context(globals, {v.name: v for v in inputs}) + checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) + + # If we branch, we also have to check the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + bb.branch_pred, ty = ExprSynthesizer(ctx).synthesize(bb.branch_pred) + bb.branch_pred, _ = to_bool(bb.branch_pred, ty, ctx) + + for succ in bb.successors: + for x, use_bb in cfg.live_before[succ].items(): + # Check that the variables requested by the successor are defined + if x not in ctx.locals and x not in ctx.globals.values: + # If the variable is defined on *some* paths, we can give a more + # informative error message + if x in cfg.maybe_ass_before[use_bb]: + # TODO: This should be "Variable x is not defined when coming + # from {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` is not defined on all control-flow paths.", + use_bb.vars.used[x], + ) + raise GuppyError( + f"Variable `{x}` is not defined", use_bb.vars.used[x] + ) + + # We have to check that used linear variables are not being outputted + if x in ctx.locals: + var = ctx.locals[x] + if var.ty.linear and var.used: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + cfg.live_before[succ][x].vars.used[x], + [var.used], + ) + + # On the other hand, unused linear variables *must* be outputted + for x, var in ctx.locals.items(): + if var.ty.linear and not var.used and x not in cfg.live_before[succ]: + # TODO: This should be "Variable x with linear type ty is not + # used in {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is " + "not used on all control-flow paths", + var.defined_at, + ) + + # Finally, we need to compute the signature of the basic block + outputs = [ + [ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals] + for succ in bb.successors + ] + + # Also prepare the successor list so we can fill it in later + checked_bb = CheckedBB(bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs)) + checked_bb.successors = [None] * len(bb.successors) # type: ignore + checked_bb.branch_pred = bb.branch_pred + return checked_bb + + +def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: + """Checks that the types of two rows match up. + + Otherwise, an error is thrown, alerting the user that a variable has different + types on different control-flow paths. + """ + map1, map2 = {v.name: v for v in row1}, {v.name: v for v in row2} + assert map1.keys() == map2.keys() + for x in map1: + v1, v2 = map1[x], map2[x] + if v1.ty != v2.ty: + # In the error message, we want to mention the variable that was first + # defined at the start. + if ( + v1.defined_at + and v2.defined_at + and line_col(v2.defined_at) < line_col(v1.defined_at) + ): + v1, v2 = v2, v1 + # We shouldn't mention temporary variables (starting with `%`) + # in error messages: + ident = ( + "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" + ) + raise GuppyError( + f"{ident} can refer to different types: " + f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", + bb.cfg.live_before[bb][v1.name].vars.used[v1.name], + [v1.defined_at, v2.defined_at], + ) diff --git a/guppy/checker/core.py b/guppy/checker/core.py new file mode 100644 index 00000000..461d150c --- /dev/null +++ b/guppy/checker/core.py @@ -0,0 +1,97 @@ +import ast +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import NamedTuple, Optional, Union + +from guppy.ast_util import AstNode +from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType, NoneType, \ + BoolType +from guppy.nodes import GlobalCall + + +@dataclass +class Variable: + """Class holding data associated with a variable.""" + + name: str + ty: GuppyType + defined_at: Optional[AstNode] + used: Optional[AstNode] + + +@dataclass +class CallableVariable(ABC, Variable): + """Abstract base class for global variables that can be called.""" + + ty: FunctionType + + @abstractmethod + def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context") -> ast.expr: + """Checks the return type of a function call against a given type.""" + + @abstractmethod + def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]: + """Synthesizes the return type of a function call.""" + + +class Globals(NamedTuple): + """Collection of names that are available on module-level. + + Separately stores names that are bound to values (i.e. module-level functions or + constants), to types, or to instance functions belonging to types. + """ + + values: dict[str, Variable] + types: dict[str, type[GuppyType]] + + @staticmethod + def default() -> "Globals": + """Generates a `Globals` instance that is populated with all core types""" + tys: dict[str, type[GuppyType]] = { + FunctionType.name: FunctionType, + TupleType.name: TupleType, + SumType.name: SumType, + NoneType.name: NoneType, + BoolType.name: BoolType, + } + return Globals({}, tys) + + def get_instance_func(self, ty: GuppyType, name: str) -> Optional[CallableVariable]: + """Looks up an instance function with a given name for a type. + + Returns `None` if the name doesn't exist or isn't a function. + """ + qualname = qualified_name(ty.__class__, name) + if qualname in self.values: + val = self.values[qualname] + if isinstance(val, CallableVariable): + return val + return None + + def __or__(self, other: "Globals") -> "Globals": + return Globals( + self.values | other.values, + self.types | other.types, + ) + + def __ior__(self, other: "Globals") -> "Globals": + self.values.update(other.values) + self.types.update(other.types) + return self + + +# Local variable mapping +Locals = dict[str, Variable] + + +class Context(NamedTuple): + """The type checking context.""" + + globals: Globals + locals: Locals + + +def qualified_name(ty: Union[type[GuppyType], str], name: str) -> str: + """Returns a qualified name for an instance function on a type.""" + ty_name = ty if isinstance(ty, str) else ty.name + return f"{ty_name}.{name}" diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py new file mode 100644 index 00000000..9b2e1ecf --- /dev/null +++ b/guppy/checker/expr_checker.py @@ -0,0 +1,322 @@ +import ast +from ast import NodeTransformer +from typing import Optional, Union, NoReturn, Any + +from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type, \ + get_type_opt +from guppy.checker.core import Context, CallableVariable, Globals +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.guppy_types import GuppyType, TupleType, FunctionType, BoolType +from guppy.hugr import val +from guppy.nodes import LocalName, GlobalName, LocalCall + +# Mapping from unary AST op to dunder method and display name +unary_table: dict[type[ast.unaryop], tuple[str, str]] = { + ast.UAdd: ("__pos__", "+"), + ast.USub: ("__neg__", "-"), + ast.Invert: ("__invert__", "~"), +} + +# Mapping from binary AST op to left dunder method, right dunder method and display name +AstOp = Union[ast.operator, ast.cmpop] +binary_table: dict[type[AstOp], tuple[str, str, str]] = { + ast.Add: ("__add__", "__radd__", "+"), + ast.Sub: ("__sub__", "__rsub__", "-"), + ast.Mult: ("__mul__", "__rmul__", "*"), + ast.Div: ("__truediv__", "__rtruediv__", "/"), + ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), + ast.Mod: ("__mod__", "__rmod__", "%"), + ast.Pow: ("__pow__", "__rpow__", "**"), + ast.LShift: ("__lshift__", "__rlshift__", "<<"), + ast.RShift: ("__rshift__", "__rrshift__", ">>"), + ast.BitOr: ("__or__", "__ror__", "|"), + ast.BitXor: ("__xor__", "__rxor__", "^"), + ast.BitAnd: ("__and__", "__rand__", "&"), + ast.MatMult: ("__matmul__", "__rmatmul__", "@"), + ast.Eq: ("__eq__", "__eq__", "=="), + ast.NotEq: ("__neq__", "__neq__", "!="), + ast.Lt: ("__lt__", "__gt__", "<"), + ast.LtE: ("__le__", "__ge__", "<="), + ast.Gt: ("__gt__", "__lt__", ">"), + ast.GtE: ("__ge__", "__le__", ">="), +} + + +class ExprChecker(AstVisitor[ast.expr]): + """Checks an expression against a type and produces a new type-annotated AST""" + + ctx: Context + + # Name for the kind of term we are currently checking against (used in errors). + # For example, "argument", "return value", or in general "expression". + _kind: str + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + self._kind = "expression" + + def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Optional[AstNode] = None) -> NoReturn: + """Raises a type error indicating that the type doesn't match.""" + if not isinstance(actual, GuppyType): + loc = loc or actual + _, actual = self._synthesize(actual) + assert loc is not None + raise GuppyTypeError( + f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc + ) + + def check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + """Checks an expression against a type. + + Returns a new desugared expression with type annotations. + """ + old_kind = self._kind + self._kind = kind or self._kind + expr = self.visit(expr, ty) + self._kind = old_kind + return with_type(ty, expr) + + def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Invokes the type synthesiser""" + return ExprSynthesizer(self.ctx).synthesize(node) + + def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: + if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): + return self._fail(ty, node) + for i, el in enumerate(node.elts): + node.elts[i] = self.check(el, ty.element_types[i]) + return node + + def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore + # Try to synthesize and then check if it matches the given type + node, synth = self._synthesize(node) + if synth != ty: + self._fail(ty, synth, node) + return node + + +class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): + ctx: Context + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + + def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Tries to synthesise a type for the given expression. + + Also returns a new desugared expression with type annotations. + """ + if ty := get_type_opt(node): + return node, ty + node, ty = self.visit(node) + return with_type(ty, node), ty + + def _check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + """Checks an expression against a given type""" + return ExprChecker(self.ctx).check(expr, ty, kind) + + def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: + ty = python_value_to_guppy_type(node.value, node, self.ctx.globals) + if ty is None: + raise GuppyError("Unsupported constant", node) + return node, ty + + def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + x = node.id + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is not None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + node, + [var.used], + ) + var.used = node + return with_loc(node, LocalName(id=x)), var.ty + elif x in self.ctx.globals.values: + # Cache value in the AST + value = self.ctx.globals.values[x] + return with_loc(node, GlobalName(id=x, value=value)), value.ty + raise InternalGuppyError( + f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " + f"been caught by program analysis!" + ) + + def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: + elems = [self.synthesize(elem) for elem in node.elts] + node.elts = [n for n, _ in elems] + return node, TupleType([ty for _, ty in elems]) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: + # We need to synthesise the argument type, so we can look up dunder methods + node.operand, op_ty = self.synthesize(node.operand) + + # Special case for the `not` operation since it is not implemented via a dunder + # method or control-flow + if isinstance(node.op, ast.Not): + node.operand, bool_ty = to_bool(node.operand, op_ty, self.ctx) + return node, bool_ty + + # Check all other unary expressions by calling out to instance dunder methods + op, display_name = unary_table[node.op.__class__] + func = self.ctx.globals.get_instance_func(op_ty, op) + if func is None: + raise GuppyTypeError( + f"Unary operator `{display_name}` not defined for argument of type " + f" `{op_ty}`", + node.operand, + ) + return func.synthesize_call([node.operand], node, self.ctx) + + def _synthesize_binary( + self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr + ) -> tuple[ast.expr, GuppyType]: + """Helper method to compile binary operators by calling out to dunder methods. + + For example, first try calling `__add__` on the left operand. If that fails, try + `__radd__` on the right operand. + """ + if op.__class__ not in binary_table: + raise GuppyError("This binary operation is not supported by Guppy.", op) + lop, rop, display_name = binary_table[op.__class__] + left_expr, left_ty = self.synthesize(left_expr) + right_expr, right_ty = self.synthesize(right_expr) + + if func := self.ctx.globals.get_instance_func(left_ty, lop): + try: + return func.synthesize_call([left_expr, right_expr], node, self.ctx) + except GuppyError: + pass + + if func := self.ctx.globals.get_instance_func(right_ty, rop): + try: + return func.synthesize_call([right_expr, left_expr], node, self.ctx) + except GuppyError: + pass + + raise GuppyTypeError( + f"Binary operator `{display_name}` not defined for arguments of type " + f"`{left_ty}` and `{right_ty}`", + node, + ) + + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: + return self._synthesize_binary(node.left, node.right, node.op, node) + + def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: + if len(node.comparators) != 1 or len(node.ops) != 1: + raise InternalGuppyError( + "BB contains chained comparison. Should have been removed during CFG " + "construction." + ) + left_expr, [op], [right_expr] = node.left, node.ops, node.comparators + return self._synthesize_binary(left_expr, right_expr, op, node) + + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: + if len(node.keywords) > 0: + raise GuppyError( + "Argument passing by keyword is not supported", node.keywords[0] + ) + node.func, ty = self.synthesize(node.func) + + # First handle direct calls of user-defined functions and extension functions + if isinstance(node.func, GlobalName) and isinstance(node.func.value, CallableVariable): + return node.func.value.synthesize_call(node.args, node, self.ctx) + + # Otherwise, it must be a function as a higher-order value + if isinstance(ty, FunctionType): + args, return_ty = synthesize_call(ty, node.args, node, self.ctx) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif f := self.ctx.globals.get_instance_func(ty, "__call__"): + return f.synthesize_call(node.args, node, self.ctx) + else: + raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `NamedExpr`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + + def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `BoolOp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `IfExp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + +def check_num_args(exp: int, act: int, node: AstNode) -> None: + """Checks that the correct number of arguments have been passed to a function.""" + if act < exp: + raise GuppyTypeError( + f"Not enough arguments passed (expected {exp}, got {act})", node + ) + if exp < act: + if isinstance(node, ast.Call): + raise GuppyTypeError("Unexpected argument", node.args[exp]) + raise GuppyTypeError( + f"Too many arguments passed (expected {exp}, got {act})", node + ) + + +def synthesize_call(func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[list[ast.expr], GuppyType]: + """Synthesizes the return type of a function call""" + check_num_args(len(func_ty.args), len(args), node) + for i, arg in enumerate(args): + args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") + return args, func_ty.returns + + +def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> list[ast.expr]: + """Checks the return type of a function call against a given type""" + args, return_ty = synthesize_call(func_ty, args, node, ctx) + if return_ty != ty: + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `{return_ty}`", node + ) + return args + + +def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr, GuppyType]: + """Tries to turn a node into a bool""" + if isinstance(node_ty, BoolType): + return node, node_ty + + func = ctx.globals.get_instance_func(node_ty, "__bool__") + if func is None: + raise GuppyTypeError( + f"Expression of type `{node_ty}` cannot be interpreted as a `bool`", + node, + ) + + # We could check the return type against bool, but we can give a better error + # message if we synthesise and compare to bool by hand + call, return_ty = func.synthesize_call([node], node, ctx) + if not isinstance(return_ty, BoolType): + raise GuppyTypeError( + f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", + node + ) + return call, return_ty + + +def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Optional[GuppyType]: + """Turns a primitive Python value into a Guppy type. + + Returns `None` if the Python value cannot be represented in Guppy. + """ + if isinstance(v, bool): + return globals.types["bool"].build(node=node) + elif isinstance(v, int): + return globals.types["int"].build(node=node) + if isinstance(v, float): + return globals.types["float"].build(node=node) + return None + diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py new file mode 100644 index 00000000..87cf9ee6 --- /dev/null +++ b/guppy/checker/func_checker.py @@ -0,0 +1,171 @@ +import ast +from dataclasses import dataclass, field +from typing import Mapping + +from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc +from guppy.cfg.bb import BB +from guppy.cfg.builder import CFGBuilder +from guppy.checker.core import Variable, Globals, Context, CallableVariable +from guppy.checker.cfg_checker import check_cfg, CheckedCFG +from guppy.checker.expr_checker import synthesize_call, check_call +from guppy.error import GuppyError +from guppy.guppy_types import FunctionType, type_from_ast, NoneType, GuppyType, \ + type_to_row +from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, LocalName, \ + NestedFunctionDef + + +@dataclass +class DefinedFunction(CallableVariable): + """A user-defined function""" + ty: FunctionType + defined_at: ast.FunctionDef + + @staticmethod + def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DefinedFunction": + ty = check_signature(func_def, globals) + return DefinedFunction(name, ty, func_def, None) + + def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall: + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) + + def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]: + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty + + +@dataclass +class CheckedFunction(DefinedFunction): + """Type checked version of a user-defined function""" + + cfg: CheckedCFG + + +def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFunction: + """Type checks a top-level function definition.""" + func_def = func.defined_at + args = func_def.args.args + returns_none = isinstance(func.ty.returns, NoneType) + assert func.ty.arg_names is not None + + cfg = CFGBuilder().build(func_def.body, returns_none, globals) + inputs = [ + Variable(x, ty, loc, None) + for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) + ] + cfg = check_cfg(cfg, inputs, func.ty.returns, globals) + return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) + + +def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) -> CheckedNestedFunctionDef: + """Type checks a local (nested) function definition.""" + func_ty = check_signature(func_def, ctx.globals) + assert func_ty.arg_names is not None + + # We've already built the CFG for this function while building the CFG of the + # enclosing function + cfg = func_def.cfg + + # Find captured variables + parent_cfg = bb.cfg + def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() + maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] + cfg.analyze(def_ass_before, maybe_ass_before) + captured = { + x: ctx.locals[x] + for x in cfg.live_before[cfg.entry_bb] + if x not in func_ty.arg_names and x in ctx.locals + } + + # Captured variables may not be linear + for v in captured.values(): + if v.ty.linear: + x = v.name + using_bb = cfg.live_before[cfg.entry_bb][x] + raise GuppyError( + f"Variable `{x}` with linear type `{v.ty}` may not be used here " + f"because it was defined in an outer scope (at {{0}})", + using_bb.vars.used[x], + [v.defined_at], + ) + + # Captured variables may never be assigned to + for bb in cfg.bbs: + for v in captured.values(): + x = v.name + if x in bb.vars.assigned: + raise GuppyError( + f"Variable `{x}` defined in an outer scope (at {{0}}) may not " + f"be assigned to", + bb.vars.assigned[x], + [v.defined_at], + ) + + # Construct inputs for checking the body CFG + inputs = list(captured.values()) + [Variable(x, ty, func_def.args.args[i], None) for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args))] + globals = ctx.globals + + # Check if the body contains a recursive occurrence of the function name + if func_def.name in cfg.live_before[cfg.entry_bb]: + if len(captured) == 0: + # If there are no captured vars, we treat the function like a global name + func = DefinedFunction(func_def.name, func_ty, func_def, None) + globals = ctx.globals | Globals({func_def.name: func}, {}) + + else: + # Otherwise, we treat it like a local name + inputs.append(Variable(func_def.name, func_def.ty, func_def, None)) + + checked_cfg = check_cfg(cfg, inputs, func_ty.returns, globals) + checked_def = CheckedNestedFunctionDef( + checked_cfg, + func_ty, + captured, + name=func_def.name, + args=func_def.args, + body=func_def.body, + decorator_list=func_def.decorator_list, + returns=func_def.returns, + type_comment=func_def.type_comment, + ) + return with_loc(func_def, checked_def) + + +def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: + """Checks the signature of a function definition and returns the corresponding + Guppy type.""" + if len(func_def.args.posonlyargs) != 0: + raise GuppyError( + "Positional-only parameters not supported", func_def.args.posonlyargs[0] + ) + if len(func_def.args.kwonlyargs) != 0: + raise GuppyError( + "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] + ) + if func_def.args.vararg is not None: + raise GuppyError("*args not supported", func_def.args.vararg) + if func_def.args.kwarg is not None: + raise GuppyError("**kwargs not supported", func_def.args.kwarg) + if func_def.returns is None: + # TODO: Error location is incorrect + if all(r.value is None for r in return_nodes_in_ast(func_def)): + raise GuppyError( + "Return type must be annotated. Try adding a `-> None` annotation.", + func_def, + ) + raise GuppyError("Return type must be annotated", func_def) + + arg_tys = [] + arg_names = [] + for i, arg in enumerate(func_def.args.args): + if arg.annotation is None: + raise GuppyError("Argument type must be annotated", arg) + ty = type_from_ast(arg.annotation, globals) + arg_tys.append(ty) + arg_names.append(arg.arg) + + ret_type = type_from_ast(func_def.returns, globals) + return FunctionType(arg_tys, ret_type, arg_names) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py new file mode 100644 index 00000000..8f772b2a --- /dev/null +++ b/guppy/checker/stmt_checker.py @@ -0,0 +1,119 @@ +import ast +from typing import Sequence + +from guppy.ast_util import set_location_from, with_loc, AstVisitor +from guppy.cfg.bb import BB, BBStatement +from guppy.checker.core import Variable, Context +from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.guppy_types import GuppyType, TupleType, type_from_ast, NoneType +from guppy.nodes import NestedFunctionDef + + +class StmtChecker(AstVisitor[BBStatement]): + + ctx: Context + bb: BB + return_ty: GuppyType + + def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: + self.ctx = ctx + self.bb = bb + self.return_ty = return_ty + + def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: + return [self.visit(s) for s in stmts] + + def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + return ExprSynthesizer(self.ctx).synthesize(node) + + def _check_expr(self, node: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + return ExprChecker(self.ctx).check(node, ty, kind) + + def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: + """Helper function to check assignments with patterns.""" + # Easiest case is if the LHS pattern is a single variable. + if isinstance(lhs, ast.Name): + # Check if we override an unused linear variable + x = lhs.id + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + self.ctx.locals[x] = Variable(x, ty, node, None) + # The only other thing we support right now are tuples + elif isinstance(lhs, ast.Tuple): + tys = ty.element_types if isinstance(ty, TupleType) else [ty] + n, m = len(lhs.elts), len(tys) + if n != m: + raise GuppyTypeError( + f"{'Too many' if n < m else 'Not enough'} values to unpack " + f"(expected {n}, got {m})", + node, + ) + for pat, el_ty in zip(lhs.elts, tys): + self._check_assign(pat, el_ty, node) + # TODO: Python also supports assignments like `[a, b] = [1, 2]` or + # `a, *b = ...`. The former would require some runtime checks but + # the latter should be easier to do (unpack and repack the rest). + else: + raise GuppyError("Assignment pattern not supported", lhs) + + def visit_Assign(self, node: ast.Assign) -> ast.stmt: + if len(node.targets) > 1: + # This is the case for assignments like `a = b = 1` + raise GuppyError("Multi assignment not supported", node) + + [target] = node.targets + node.value, ty = self._synth_expr(node.value) + self._check_assign(target, ty, node) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: + if node.value is None: + raise GuppyError(f"Variable declaration is not supported. Assignment is required", node) + ty = type_from_ast(node.annotation, self.ctx.globals) + node.value = self._check_expr(node.value, ty) + self._check_assign(node.target, ty, node) + return node + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: + bin_op = with_loc(node, ast.BinOp(left=node.target, op=node.op, right=node.value)) + assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) + return self.visit_Assign(assign) + + def visit_Expr(self, node: ast.Expr) -> ast.stmt: + # An expression statement where the return value is discarded + node.value, _ = self._synth_expr(node.value) + return node + + def visit_Return(self, node: ast.Return) -> ast.stmt: + if node.value is not None: + node.value = self._check_expr(node.value, self.return_ty, "return value") + elif not isinstance(self.return_ty, NoneType): + raise GuppyTypeError( + f"Expected return value of type `{self.return_ty}`", None + ) + return node + + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: + from guppy.checker.func_checker import check_nested_func_def + + func_def = check_nested_func_def(node, self.bb, self.ctx) + self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def, None) + return func_def + + def visit_If(self, node: ast.If) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_While(self, node: ast.While) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Break(self, node: ast.Break) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Continue(self, node: ast.Continue) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") From bba9c96894c4a5c6e70c487c1fe84d3bf4bdd5f8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:56:24 +0000 Subject: [PATCH 02/77] Factor out graph generation code --- guppy/compiler/__init__.py | 0 guppy/compiler/cfg_compiler.py | 161 ++++++++++++++++++++++++++++++++ guppy/compiler/core.py | 111 ++++++++++++++++++++++ guppy/compiler/expr_compiler.py | 116 +++++++++++++++++++++++ guppy/compiler/func_compiler.py | 85 +++++++++++++++++ 5 files changed, 473 insertions(+) create mode 100644 guppy/compiler/__init__.py create mode 100644 guppy/compiler/cfg_compiler.py create mode 100644 guppy/compiler/core.py create mode 100644 guppy/compiler/expr_compiler.py create mode 100644 guppy/compiler/func_compiler.py diff --git a/guppy/compiler/__init__.py b/guppy/compiler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py new file mode 100644 index 00000000..bd31be7e --- /dev/null +++ b/guppy/compiler/cfg_compiler.py @@ -0,0 +1,161 @@ +import functools +from typing import Sequence + +from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature +from guppy.checker.core import Variable +from guppy.compiler.core import CompiledGlobals, is_return_var, DFContainer, return_var, \ + PortVariable +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.compiler.stmt_compiler import StmtCompiler +from guppy.guppy_types import TupleType, SumType, type_to_row +from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV + + +def compile_cfg(cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals) -> None: + """Compiles a CFG to Hugr.""" + insert_return_vars(cfg) + + blocks: dict[CheckedBB, CFNode] = {} + for bb in cfg.bbs: + blocks[bb] = compile_bb(bb, graph, parent, globals) + for bb in cfg.bbs: + for succ in bb.successors: + graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) + + +def compile_bb(bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals) -> CFNode: + """Compiles a single basic block to Hugr.""" + inputs = sort_vars(bb.sig.input_row) + + # The exit BB is completely empty + if len(bb.successors) == 0: + assert len(bb.statements) == 0 + return graph.add_exit([v.ty for v in inputs], parent) + + # Otherwise, we use a regular `Block` node + block = graph.add_block(parent) + + # Add input node and compile the statements + inp = graph.add_input(output_tys=[v.ty for v in inputs], parent=block) + dfg = DFContainer( + block, + { + v.name: PortVariable(v.name, inp.out_port(i), v.defined_at, None) + for (i, v) in enumerate(inputs) + }, + ) + dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) + + # If we branch, we also have to compile the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg) + else: + # Even if we don't branch, we still have to add a `Sum(())` predicates + unit = graph.add_make_tuple([], parent=block).out_port(0) + branch_port = graph.add_tag( + variants=[TupleType([])], tag=0, inp=unit, parent=block + ).out_port(0) + + # Finally, we have to add the block output. + outputs: Sequence[Variable] + if len(bb.successors) == 1: + # The easy case is if we don't branch: We just output all variables that are + # specified by the signature + [outputs] = bb.sig.output_rows + else: + # If we branch and the branches use the same variables, then we can use a + # regular output + first, *rest = bb.sig.output_rows + if all({v.name for v in first} == {v.name for v in r} for r in rest): + outputs = first + else: + # Otherwise, we have to output a Sum-type predicate: We put all non-linear + # variables into the branch predicate and all linear variables in the normal + # output (since they are shared between all successors). This is in line + # with the ordering on variables which puts linear variables at the end. + # The only exception are return vars which must be outputted in order. + branch_port = choose_vars_for_pred( + graph=graph, + pred=branch_port, + output_vars=[ + [v for v in row if not v.ty.linear or is_return_var(v.name)] + for row in bb.sig.output_rows + ], + dfg=dfg, + ) + outputs = [v for v in first if v.ty.linear and not is_return_var(v.name)] + + graph.add_output( + inputs=[branch_port] + [dfg[v.name].port for v in sort_vars(outputs)], + parent=block, + ) + return block + + +def insert_return_vars(cfg: CheckedCFG) -> None: + """Patches a CFG by annotating dummy return variables in the BB signatures. + + The statement compiler turns `return` statements into assignments of dummy variables + `%ret0`, `%ret1`, etc. We update the exit BB signature to make sure they are + correctly outputted. + """ + return_vars = [Variable(return_var(i), ty, None, None) for i, ty in enumerate(type_to_row(cfg.output_ty))] + # Before patching, the exit BB shouldn't take any inputs + assert len(cfg.exit_bb.sig.input_row) == 0 + cfg.exit_bb.sig = Signature(return_vars, cfg.exit_bb.sig.output_rows) + # Also patch the predecessors + for pred in cfg.exit_bb.predecessors: + # The exit BB will be the only successor + assert len(pred.sig.output_rows) == 1 and len(pred.sig.output_rows[0]) == 0 + pred.sig = Signature(pred.sig.input_row, [return_vars]) + + +def choose_vars_for_pred( + graph: Hugr, pred: OutPortV, output_vars: list[VarRow], dfg: DFContainer +) -> OutPortV: + """Selects an output based on a predicate. + + Given `pred: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, constructs + a predicate value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. + """ + assert isinstance(pred.ty, SumType) + assert len(pred.ty.element_types) == len(output_vars) + tuples = [ + graph.add_make_tuple( + inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], parent=dfg.node + ).out_port(0) + for vs in output_vars + ] + tys = [t.ty for t in tuples] + conditional = graph.add_conditional( + cond_input=pred, inputs=tuples, parent=dfg.node + ) + for i, ty in enumerate(tys): + case = graph.add_case(conditional) + inp = graph.add_input(output_tys=tys, parent=case).out_port(i) + tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) + graph.add_output(inputs=[tag], parent=case) + return conditional.add_out_port(SumType(tys)) + + +def compare_var(x: Variable, y: Variable) -> int: + """Defines a `<` order on variables. + + We use this to determine in which order variables are outputted from basic blocks. + We need to output linear variables at the end, so we do a lexicographic ordering of + linearity and name. The only exception are return vars which must be outputted in + order. + """ + if is_return_var(x.name) and is_return_var(y.name): + return -1 if x.name < y.name else 1 + return -1 if (x.ty.linear, x.name) < (y.ty.linear, y.name) else 1 + + +def sort_vars(row: VarRow) -> list[Variable]: + """Sorts a row of variables. + + This determines the order in which they are outputted from a BB. + """ + return sorted(row, key=functools.cmp_to_key(compare_var)) + diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py new file mode 100644 index 00000000..1d0269ed --- /dev/null +++ b/guppy/compiler/core.py @@ -0,0 +1,111 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, Iterator, NamedTuple + +from guppy.ast_util import AstNode +from guppy.checker.core import Variable, CallableVariable +from guppy.guppy_types import FunctionType +from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr + + +@dataclass +class PortVariable(Variable): + """Represents a local variable in a dataflow graph. + + Local variables are associated with a port in the Hugr. + """ + + port: OutPortV + + def __init__(self, name: str, port: OutPortV, defined_at: Optional[AstNode], used: Optional[AstNode] = None): + super().__init__(name, port.ty, defined_at, used) + object.__setattr__(self, "port", port) + + +class CompiledVariable(ABC, Variable): + """Abstract base class for compiled global module-level variables.""" + + @abstractmethod + def load(self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode) -> OutPortV: + """Loads the variable as a value into a local dataflow graph.""" + + +class CompiledFunction(CompiledVariable, CallableVariable): + """Abstract base class a global module-level function.""" + + ty: FunctionType + + @abstractmethod + def compile_call( + self, + args: list[OutPortV], + dfg: "DFContainer", + graph: Hugr, + globals: "CompiledGlobals", + node: AstNode, + ) -> list[OutPortV]: + """Compiles a call to the function.""" + + +CompiledGlobals = dict[str, CompiledVariable] +CompiledLocals = dict[str, PortVariable] + + +@dataclass +class DFContainer: + """A dataflow graph under construction. + + This class is passed through the entire compilation pipeline and stores the node + whose dataflow child-graph is currently being constructed as well as all live local + variables. Note that the variable map is mutated in-place and always reflects the + current compilation state. + """ + + node: DFContainingNode + locals: CompiledLocals + + def __getitem__(self, item: str) -> PortVariable: + return self.locals[item] + + def __setitem__(self, key: str, value: PortVariable) -> None: + self.locals[key] = value + + def __iter__(self) -> Iterator[PortVariable]: + return iter(self.locals.values()) + + def __contains__(self, item: str) -> bool: + return item in self.locals + + def __copy__(self) -> "DFContainer": + # Make a copy of the var map so that mutating the copy doesn't + # mutate our variable mapping + return DFContainer(self.node, self.locals.copy()) + + def get_var(self, name: str) -> Optional[PortVariable]: + return self.locals.get(name, None) + + +class CompilerBase(ABC): + """Base class for the Guppy compiler.""" + + graph: Hugr + globals: CompiledGlobals + + def __init__(self, graph: Hugr, globals: CompiledGlobals) -> None: + self.graph = graph + self.globals = globals + + +def return_var(n: int) -> str: + """Name of the dummy variable for the n-th return value of a function. + + During compilation, we treat return statements like assignments of dummy variables. + For example, the statement `return e0, e1, e2` is treated like `%ret0 = e0 ; %ret1 = + e1 ; %ret2 = e2`. This way, we can reuse our existing mechanism for passing of live + variables between basic blocks.""" + return f"%ret{n}" + + +def is_return_var(x: str) -> bool: + """Checks whether the given name is a dummy return variable.""" + return x.startswith("%ret") diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py new file mode 100644 index 00000000..b5c4c65b --- /dev/null +++ b/guppy/compiler/expr_compiler.py @@ -0,0 +1,116 @@ +import ast +from typing import Any, Optional + +from guppy.ast_util import AstVisitor, get_type +from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction +from guppy.error import InternalGuppyError +from guppy.guppy_types import FunctionType, type_to_row, BoolType +from guppy.hugr import ops, val +from guppy.hugr.hugr import OutPortV +from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall + + +class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): + """A compiler from Guppy expressions to Hugr.""" + + dfg: DFContainer + + def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: + """Compiles an expression and returns a single port holding the output value.""" + self.dfg = dfg + with self.graph.parent(dfg.node): + res = self.visit(expr) + return res + + def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: + """Compiles a row expression and returns a list of ports, one for each value in + the row. + + On Python-level, we treat tuples like rows on top-level. However, nested tuples + are treated like regular Guppy tuples. + """ + return [self.compile(e, dfg) for e in expr_to_row(expr)] + + def visit_Constant(self, node: ast.Constant) -> OutPortV: + if value := python_value_to_hugr(node.value): + const = self.graph.add_constant(value, get_type(node)).out_port(0) + return self.graph.add_load_constant(const).out_port(0) + raise InternalGuppyError("Unsupported constant expression in compiler") + + def visit_LocalName(self, node: LocalName) -> OutPortV: + return self.dfg[node.id].port + + def visit_GlobalName(self, node: GlobalName) -> OutPortV: + return self.globals[node.id].load(self.dfg, self.graph, self.globals, node) + + def visit_Name(self, node: ast.Name) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Tuple(self, node: ast.Tuple) -> OutPortV: + return self.graph.add_make_tuple( + inputs=[self.visit(e) for e in node.elts] + ).out_port(0) + + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: + """Groups function return values into a tuple""" + if len(returns) != 1: + return self.graph.add_make_tuple(inputs=returns).out_port(0) + return returns[0] + + def visit_LocalCall(self, node: LocalCall) -> OutPortV: + func = self.visit(node.func) + assert isinstance(func.ty, FunctionType) + + args = [self.visit(arg) for arg in node.args] + call = self.graph.add_indirect_call(func, args) + rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.returns)))] + return self._pack_returns(rets) + + def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: + func = self.globals[node.func.name] + assert isinstance(func, CompiledFunction) + + args = [self.visit(arg) for arg in node.args] + rets = func.compile_call(args, self.dfg, self.graph, self.globals, node) + return self._pack_returns(rets) + + def visit_Call(self, node: ast.Call) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: + # The only case that is not desugared by the type checker is the `not` operation + # since it is not implemented via a dunder method + if isinstance(node.op, ast.Not): + arg = self.visit(node.operand) + return self.graph.add_node( + ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] + ).add_out_port(BoolType()) + + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_BinOp(self, node: ast.BinOp) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Compare(self, node: ast.Compare) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + +def expr_to_row(expr: ast.expr) -> list[ast.expr]: + """Turns an expression into a row expressions by unpacking top-level tuples.""" + return expr.elts if isinstance(expr, ast.Tuple) else [expr] + + +def python_value_to_hugr(v: Any) -> Optional[val.Value]: + """Turns a Python value into a Hugr value. + + Returns None if the Python value cannot be represented in Guppy. + """ + from guppy.prelude._internal import int_value, bool_value, float_value + + if isinstance(v, bool): + return bool_value(v) + elif isinstance(v, int): + return int_value(v) + elif isinstance(v, float): + return float_value(v) + return None diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py new file mode 100644 index 00000000..f0f62cb6 --- /dev/null +++ b/guppy/compiler/func_compiler.py @@ -0,0 +1,85 @@ +import ast +from dataclasses import dataclass, field +from typing import Mapping, Sequence + +from guppy.ast_util import AstNode +from guppy.checker.core import Variable +from guppy.checker.func_checker import CheckedFunction, DefinedFunction +from guppy.compiler.cfg_compiler import compile_cfg +from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer, \ + PortVariable +from guppy.guppy_types import type_to_row, FunctionType +from guppy.hugr.hugr import Hugr, DFContainingNode, OutPortV, DFContainingVNode +from guppy.nodes import CheckedNestedFunctionDef + + +@dataclass +class CompiledFunctionDef(DefinedFunction, CompiledFunction): + node: DFContainingVNode + + def load(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) + + def compile_call(self, args: list[OutPortV], dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> list[OutPortV]: + call = graph.add_call(self.node.out_port(0), args, dfg.node) + return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] + + +def compile_global_func_def(func: CheckedFunction, def_node: DFContainingVNode, graph: Hugr, globals: CompiledGlobals) -> CompiledFunctionDef: + """Compiles a top-level function definition to Hugr.""" + def_input = graph.add_input(parent=def_node) + cfg_node = graph.add_cfg( + def_node, inputs=[def_input.add_out_port(ty) for ty in func.ty.args] + ) + compile_cfg(func.cfg, graph, cfg_node, globals) + + # Add output node for the cfg + graph.add_output(inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], parent=def_node) + + return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) + + +def compile_local_func_def(func: CheckedNestedFunctionDef, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals) -> PortVariable: + """Compiles a local (nested) function definition to Hugr.""" + assert func.ty.arg_names is not None + + # Pick an order for the captured variables + captured = list(func.captured.values()) + + # Prepend captured variables to the function arguments + closure_ty = FunctionType( + [v.ty for v in captured] + list(func.ty.args), + func.ty.returns, + [v.name for v in captured] + list(func.ty.arg_names), + ) + + def_node = graph.add_def(closure_ty, dfg.node, func.name) + def_input = graph.add_input(parent=def_node) + input_ports = [def_input.add_out_port(ty) for ty in closure_ty.args] + + # If we have captured variables and the body contains a recursive occurrence of + # the function itself, then we provide the partially applied function as a local + # variable + if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: + loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0) + partial = graph.add_partial(loaded, [def_input.out_port(i) for i in range(len(captured))], def_node) + input_ports.append(partial.out_port(0)) + func.cfg.input_tys.append(func.ty) + else: + # Otherwise, we treat the function like a normal global variable + globals = globals | {func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node)} + + # Compile the CFG + cfg_node = graph.add_cfg(def_node, inputs=input_ports) + compile_cfg(func.cfg, graph, cfg_node, globals) + + # Add output node for the cfg + graph.add_output(inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], parent=def_node) + + # Finally, load the function into the local data-flow graph + loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) + if len(captured) > 0: + loaded = graph.add_partial(loaded, [dfg[v.name].port for v in captured], dfg.node).out_port(0) + + return PortVariable(func.name, loaded, func) + From eb1a3271bcb3bd06d51a0b79eb06d19bc88b4ef4 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:57:37 +0000 Subject: [PATCH 03/77] Update cfg --- guppy/cfg/bb.py | 89 ++---------- guppy/cfg/builder.py | 20 +-- guppy/cfg/cfg.py | 321 +++++-------------------------------------- 3 files changed, 57 insertions(+), 373 deletions(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index dec4485a..f42df90c 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -1,14 +1,14 @@ import ast +from abc import ABC from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING, Union, Any +from typing import Optional, TYPE_CHECKING, Union, Sequence +from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.compiler_base import RawVariable, return_var -from guppy.guppy_types import FunctionType -from guppy.hugr.hugr import CFNode +from guppy.nodes import NestedFunctionDef if TYPE_CHECKING: - from guppy.cfg.cfg import CFG + from guppy.cfg.cfg import BaseCFG @dataclass @@ -34,61 +34,24 @@ def update_used(self, node: ast.AST) -> None: self.used[name.id] = name -VarRow = Sequence[RawVariable] - - -@dataclass(frozen=True) -class Signature: - """The signature of a basic block. - - Stores the inout/output variables with their types. - """ - - input_row: VarRow - output_rows: Sequence[VarRow] # One for each successor - - -@dataclass(frozen=True) -class CompiledBB: - """The result of compiling a basic block. - - Besides the corresponding node in the graph, we also store the signature of the - basic block with type information. - """ - - node: CFNode - bb: "BB" - sig: Signature - - -class NestedFunctionDef(ast.FunctionDef): - cfg: "CFG" - ty: FunctionType - - def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.cfg = cfg - self.ty = ty - - BBStatement = Union[ast.Assign, ast.AugAssign, ast.Expr, ast.Return, NestedFunctionDef] @dataclass(eq=False) # Disable equality to recover hash from `object` -class BB: +class BB(ABC): """A basic block in a control flow graph.""" idx: int # Pointer to the CFG that contains this node - cfg: "CFG" + cfg: "BaseCFG[Self]" # AST statements contained in this BB statements: list[BBStatement] = field(default_factory=list) # Predecessor and successor BBs - predecessors: list["BB"] = field(default_factory=list) - successors: list["BB"] = field(default_factory=list) + predecessors: list[Self] = field(default_factory=list) + successors: list[Self] = field(default_factory=list) # If the BB has multiple successors, we need a predicate to decide to which one to # jump to @@ -107,13 +70,9 @@ def vars(self) -> VariableStats: assert self._vars is not None return self._vars - def compute_variable_stats(self, num_returns: int) -> None: - """Determines which variables are assigned/used in this BB. - - This also requires the expected number of returns of the whole CFG in order to - process `return` statements. - """ - visitor = VariableVisitor(self, num_returns) + def compute_variable_stats(self) -> None: + """Determines which variables are assigned/used in this BB.""" + visitor = VariableVisitor(self) for s in self.statements: visitor.visit(s) self._vars = visitor.stats @@ -121,26 +80,15 @@ def compute_variable_stats(self, num_returns: int) -> None: if self.branch_pred is not None: self._vars.update_used(self.branch_pred) - # In the `StatementCompiler`, we're going to turn return statements into - # assignments of dummy variables `%ret_xxx`. Thus, we have to register those - # variables as being used in the exit BB - if len(self.successors) == 0: - self._vars.used |= { - return_var(i): ast.Name(return_var(i), ast.Load) - for i in range(num_returns) - } - class VariableVisitor(ast.NodeVisitor): """Visitor that computes used and assigned variables in a BB.""" bb: BB stats: VariableStats - num_returns: int - def __init__(self, bb: BB, num_returns: int): + def __init__(self, bb: BB): self.bb = bb - self.num_returns = num_returns self.stats = VariableStats() def visit_Assign(self, node: ast.Assign) -> None: @@ -155,22 +103,13 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None: for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node - def visit_Return(self, node: ast.Return) -> None: - if node.value is not None: - self.stats.update_used(node.value) - - # In the `StatementCompiler`, we're going to turn return statements into - # assignments of dummy variables `%ret_xxx`. To make the liveness analysis work, - # we have to register those variables as being assigned here - self.stats.assigned |= {return_var(i): node for i in range(self.num_returns)} - def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # In order to compute the used external variables in a nested function # definition, we have to run live variable analysis first from guppy.cfg.analysis import LivenessAnalysis for bb in node.cfg.bbs: - bb.compute_variable_stats(len(node.ty.returns)) + bb.compute_variable_stats() live = LivenessAnalysis().run(node.cfg.bbs) # Only store used *external* variables: things defined in the current BB, as diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 80e98559..56c61310 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -3,11 +3,12 @@ from typing import Optional, Iterator, Union, NamedTuple from guppy.ast_util import set_location_from, AstVisitor -from guppy.cfg.bb import BB, NestedFunctionDef +from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG -from guppy.compiler_base import Globals +from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError - +from guppy.guppy_types import NoneType +from guppy.nodes import NestedFunctionDef # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -31,10 +32,9 @@ class CFGBuilder(AstVisitor[Optional[BB]]): """Constructs a CFG from ast nodes.""" cfg: CFG - num_returns: int globals: Globals - def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CFG: + def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG: """Builds a CFG from a list of ast nodes. We also require the expected number of return ports for the whole CFG. This is @@ -42,7 +42,6 @@ def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CF variables. """ self.cfg = CFG() - self.num_returns = num_returns self.globals = globals final_bb = self.visit_stmts( @@ -52,7 +51,7 @@ def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CF # If we're still in a basic block after compiling the whole body, we have to add # an implicit void return if final_bb is not None: - if num_returns > 0: + if not returns_none: raise GuppyError("Expected return statement", nodes[-1]) self.cfg.link(final_bb, self.cfg.exit_bb) @@ -166,10 +165,11 @@ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> Optional[BB]: def visit_FunctionDef( self, node: ast.FunctionDef, bb: BB, jumps: Jumps ) -> Optional[BB]: - from guppy.function import FunctionDefCompiler + from guppy.checker.func_checker import check_signature - func_ty = FunctionDefCompiler.validate_signature(node, self.globals) - cfg = CFGBuilder().build(node.body, len(func_ty.returns), self.globals) + func_ty = check_signature(node, self.globals) + returns_none = isinstance(func_ty.returns, NoneType) + cfg = CFGBuilder().build(node.body, returns_none, self.globals) new_node = NestedFunctionDef( cfg, diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py index ef6f0314..6bd1e4b0 100644 --- a/guppy/cfg/cfg.py +++ b/guppy/cfg/cfg.py @@ -1,5 +1,4 @@ -import collections -from typing import Optional +from typing import Optional, TypeVar, Generic from guppy.cfg.analysis import ( LivenessDomain, @@ -9,41 +8,50 @@ MaybeAssignmentDomain, Result, ) -from guppy.cfg.bb import ( - BB, - VarRow, - Signature, - CompiledBB, - BBStatement, -) -from guppy.compiler_base import DFContainer, Variable, Globals, is_return_var -from guppy.error import GuppyError, GuppyTypeError -from guppy.ast_util import line_col -from guppy.expression import ExpressionCompiler -from guppy.guppy_types import GuppyType, TupleType, SumType -from guppy.hugr.hugr import Node, Hugr, OutPortV -from guppy.statement import StatementCompiler +from guppy.cfg.bb import BB, BBStatement + +T = TypeVar("T", bound=BB) -class CFG: - """A control-flow graph of basic blocks.""" - bbs: list[BB] - entry_bb: BB - exit_bb: BB +class BaseCFG(Generic[T]): + """Abstract base class for control-flow graphs.""" + + bbs: list[T] + entry_bb: T + exit_bb: T live_before: Result[LivenessDomain] ass_before: Result[DefAssignmentDomain] maybe_ass_before: Result[MaybeAssignmentDomain] - def __init__(self) -> None: - self.bbs = [] - self.entry_bb = self.new_bb() - self.exit_bb = self.new_bb() + def __init__(self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None): + self.bbs = bbs + if entry_bb: + self.entry_bb = entry_bb + if exit_bb: + self.exit_bb = exit_bb self.live_before = {} self.ass_before = {} self.maybe_ass_before = {} + def analyze(self, def_ass_before: set[str], maybe_ass_before: set[str]) -> None: + for bb in self.bbs: + bb.compute_variable_stats() + self.live_before = LivenessAnalysis().run(self.bbs) + self.ass_before, self.maybe_ass_before = AssignmentAnalysis( + self.bbs, def_ass_before, maybe_ass_before + ).run_unpacked(self.bbs) + + +class CFG(BaseCFG[BB]): + """A control-flow graph of unchecked basic blocks.""" + + def __init__(self) -> None: + super().__init__([]) + self.entry_bb = self.new_bb() + self.exit_bb = self.new_bb() + def new_bb(self, *preds: BB, statements: Optional[list[BBStatement]] = None) -> BB: """Adds a new basic block to the CFG.""" bb = BB( @@ -58,266 +66,3 @@ def link(self, src_bb: BB, tgt_bb: BB) -> None: """Adds a control-flow edge between two basic blocks.""" src_bb.successors.append(tgt_bb) tgt_bb.predecessors.append(src_bb) - - def analyze( - self, num_returns: int, def_ass_before: set[str], maybe_ass_before: set[str] - ) -> None: - for bb in self.bbs: - bb.compute_variable_stats(num_returns) - self.live_before = LivenessAnalysis().run(self.bbs) - self.ass_before, self.maybe_ass_before = AssignmentAnalysis( - self.bbs, def_ass_before, maybe_ass_before - ).run_unpacked(self.bbs) - - def compile( - self, - graph: Hugr, - input_row: VarRow, - return_tys: list[GuppyType], - parent: Node, - globals: Globals, - ) -> None: - """Compiles the CFG.""" - - # First, we need to run program analysis - ass_before = {v.name for v in input_row} - self.analyze(len(return_tys), ass_before, ass_before) - - # We start by compiling the entry BB - entry_compiled = self._compile_bb( - self.entry_bb, input_row, return_tys, graph, parent, globals - ) - compiled = {self.entry_bb: entry_compiled} - - # Visit all control-flow edges in BFS order. We can't just do a normal loop over - # all BBs since the input types for a BB are computed by compiling a predecessor - queue = collections.deque( - (entry_compiled, i, succ) for i, succ in enumerate(self.entry_bb.successors) - ) - while len(queue) > 0: - pred, num_output, bb = queue.popleft() - out_row = pred.sig.output_rows[num_output] - - if bb in compiled: - # If the BB was already compiled, we just have to check that the - # signatures match. - self._check_rows_match(out_row, compiled[bb].sig.input_row, bb) - else: - # Otherwise, compile the BB and enqueue its successors - compiled_bb = self._compile_bb( - bb, out_row, return_tys, graph, parent, globals - ) - queue += [ - (compiled_bb, i, succ) for i, succ in enumerate(bb.successors) - ] - compiled[bb] = compiled_bb - - graph.add_edge( - pred.node.out_port(num_output), compiled[bb].node.in_port(None) - ) - - def _compile_bb( - self, - bb: BB, - input_row: VarRow, - return_tys: list[GuppyType], - graph: Hugr, - parent: Node, - globals: Globals, - ) -> CompiledBB: - """Compiles a single basic block.""" - - # The exit BB is completely empty - if len(bb.successors) == 0: - block = graph.add_exit(return_tys, parent) - return CompiledBB(block, bb, Signature(input_row, [])) - - # For the entry BB we have to separately check that all used variables are - # defined. For all other BBs, this will be checked when compiling a predecessor. - if len(bb.predecessors) == 0: - for x, use in bb.vars.used.items(): - if x not in self.ass_before[bb] and x not in globals.values: - raise GuppyError(f"Variable `{x}` is not defined", use) - - # Compile the basic block - block = graph.add_block(parent, num_successors=len(bb.successors)) - inp = graph.add_input(output_tys=[v.ty for v in input_row], parent=block) - dfg = DFContainer( - block, - { - v.name: Variable(v.name, inp.out_port(i), v.defined_at) - for (i, v) in enumerate(input_row) - }, - ) - stmt_compiler = StatementCompiler(graph, globals) - dfg = stmt_compiler.compile_stmts(bb.statements, bb, dfg, return_tys) - - # If we branch, we also have to compile the branch predicate - if len(bb.successors) > 1: - assert bb.branch_pred is not None - expr_compiler = ExpressionCompiler(graph, globals) - port = expr_compiler.compile(bb.branch_pred, dfg) - func = globals.get_instance_func(port.ty, "__bool__") - if func is None: - raise GuppyTypeError( - f"Expression of type `{port.ty}` cannot be interpreted as a `bool`", - bb.branch_pred, - ) - [branch_port] = func.compile_call( - [port], dfg, graph, globals, bb.branch_pred - ) - - for succ in bb.successors: - for x, use_bb in self.live_before[succ].items(): - # Check that the variable requested by the successor are defined - if x not in dfg and x not in globals.values: - # If the variable is defined on *some* paths, we can give a more - # informative error message - if x in self.maybe_ass_before[use_bb]: - # TODO: This should be "Variable x is not defined when coming - # from {bb}". But for this we need a way to associate BBs with - # source locations. - raise GuppyError( - f"Variable `{x}` is not defined on all control-flow paths.", - use_bb.vars.used[x], - ) - raise GuppyError( - f"Variable `{x}` is not defined", use_bb.vars.used[x] - ) - - # We have to check that used linear variables are not being outputted - if x in dfg: - var = dfg[x] - if var.ty.linear and var.used: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - self.live_before[succ][x].vars.used[x], - [var.used], - ) - - # On the other hand, unused linear variables *must* be outputted - for x, var in dfg.variables.items(): - if var.ty.linear and not var.used and x not in self.live_before[succ]: - # TODO: This should be "Variable x with linear type ty is not - # used in {bb}". But for this we need a way to associate BBs with - # source locations. - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is " - "not used on all control-flow paths", - var.defined_at, - ) - - # Finally, we have to add the block output. The easy case is if we don't branch: - # We just output the variables that are live in the successor - output_vars = sorted( - dfg[x] for x in self.live_before[bb.successors[0]] if x in dfg - ) - if len(bb.successors) == 1: - # Even if we don't branch, we still have to add a `Sum(())` predicate - unit = graph.add_make_tuple([], parent=block).out_port(0) - branch_port = graph.add_tag( - variants=[TupleType([])], tag=0, inp=unit, parent=block - ).out_port(0) - else: - # If we branch and the branches use different variables, we have to output a - # Sum-type predicate - first, *rest = bb.successors - if any( - self.live_before[r].keys() & dfg.variables.keys() - != self.live_before[first].keys() & dfg.variables.keys() - for r in rest - ): - # We put all non-linear variables into the branch predicate and all - # linear variables in the normal output (since they are shared between - # all successors). This is in line with the definition of `<` on - # variables which puts linear variables at the end. The only exception - # are return vars which must be outputted in order. - branch_port = self._choose_vars_for_pred( - graph=graph, - pred=branch_port, - output_vars=[ - sorted( - x - for x in self.live_before[succ] - if x in dfg and (not dfg[x].ty.linear or is_return_var(x)) - ) - for succ in bb.successors - ], - dfg=dfg, - ) - output_vars = sorted( - dfg[x] - # We can look at `successors[0]` here since all successors must have - # the same `live_before` linear variables - for x in self.live_before[bb.successors[0]] - if x in dfg and dfg[x].ty.linear and not is_return_var(x) - ) - - graph.add_output( - inputs=[branch_port] + [v.port for v in output_vars], parent=block - ) - output_rows = [ - sorted([dfg[x] for x in self.live_before[succ] if x in dfg]) - for succ in bb.successors - ] - - return CompiledBB(block, bb, Signature(input_row, output_rows)) - - def _check_rows_match(self, row1: VarRow, row2: VarRow, bb: BB) -> None: - """Checks that the types of two rows match up. - - Otherwise, an error is thrown, alerting the user that a variable has different - types on different control-flow paths. - """ - assert len(row1) == len(row2) - for v1, v2 in zip(row1, row2): - assert v1.name == v2.name - if v1.ty != v2.ty: - # In the error message, we want to mention the variable that was first - # defined at the start. - if ( - v1.defined_at - and v2.defined_at - and line_col(v2.defined_at) < line_col(v1.defined_at) - ): - v1, v2 = v2, v1 - # We shouldn't mention temporary variables (starting with `%`) - # in error messages: - ident = ( - "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" - ) - raise GuppyError( - f"{ident} can refer to different types: " - f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", - self.live_before[bb][v1.name].vars.used[v1.name], - [v1.defined_at, v2.defined_at], - ) - - @staticmethod - def _choose_vars_for_pred( - graph: Hugr, pred: OutPortV, output_vars: list[list[str]], dfg: DFContainer - ) -> OutPortV: - """Selects an output based on a predicate. - - Given `pred: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, - constructs a predicate value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. - """ - assert isinstance(pred.ty, SumType) - assert len(pred.ty.element_types) == len(output_vars) - tuples = [ - graph.add_make_tuple( - inputs=[dfg[x].port for x in sorted(vs) if x in dfg], parent=dfg.node - ).out_port(0) - for vs in output_vars - ] - tys = [t.ty for t in tuples] - conditional = graph.add_conditional( - cond_input=pred, inputs=tuples, parent=dfg.node - ) - for i, ty in enumerate(tys): - case = graph.add_case(conditional) - inp = graph.add_input(output_tys=tys, parent=case).out_port(i) - tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) - graph.add_output(inputs=[tag], parent=case) - return conditional.add_out_port(SumType(tys)) From 4390ddd05b486dd5cb15aa9623331e3cd3158f16 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:58:31 +0000 Subject: [PATCH 04/77] Update files --- guppy/ast_util.py | 88 ++++++++++++++++++++++++++++++++++- guppy/error.py | 103 ++++++++++++++++++++++++++++++++++++++--- guppy/guppy_types.py | 107 ++++++++++++++++++++++++++++++++----------- guppy/hugr/hugr.py | 12 ++--- 4 files changed, 269 insertions(+), 41 deletions(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 356a28a7..77603220 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,6 +1,8 @@ import ast -from typing import Any, TypeVar, Generic, Union +from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING +if TYPE_CHECKING: + from guppy.guppy_types import GuppyType AstNode = Union[ ast.AST, @@ -111,8 +113,90 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None: node.end_lineno = loc.end_lineno node.end_col_offset = loc.end_col_offset + source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc) + assert source is not None and file is not None and line_offset is not None + annotate_location(node, source, file, line_offset) -def is_empty_body(func_ast: ast.FunctionDef) -> bool: + +def annotate_location(node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True) -> None: + setattr(node, "line_offset", line_offset) + setattr(node, "file", file) + setattr(node, "source", source) + + if recurse: + for field, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + annotate_location(item, source, file, line_offset, recurse) + elif isinstance(value, ast.AST): + annotate_location(value, source, file, line_offset, recurse) + + +def get_file(node: AstNode) -> Optional[str]: + """Tries to retrieve a file annotation from an AST node.""" + try: + file = getattr(node, "file") + return file if isinstance(file, str) else None + except AttributeError: + return None + + +def get_source(node: AstNode) -> Optional[str]: + """Tries to retrieve a source annotation from an AST node.""" + try: + source = getattr(node, "source") + return source if isinstance(source, str) else None + except AttributeError: + return None + + +def get_line_offset(node: AstNode) -> Optional[int]: + """Tries to retrieve a line offset annotation from an AST node.""" + try: + line_offset = getattr(node, "line_offset") + return line_offset if isinstance(line_offset, int) else None + except AttributeError: + return None + + +A = TypeVar("A", bound=ast.AST) + + +def with_loc(loc: ast.AST, node: A) -> A: + """Copy source location from one AST node to the other.""" + set_location_from(node, loc) + return node + + +def with_type(ty: "GuppyType", node: A) -> A: + """Annotates an AST node with a type.""" + setattr(node, "type", ty) + return node + + +def get_type_opt(node: AstNode) -> Optional["GuppyType"]: + """Tries to retrieve a type annotation from an AST node.""" + from guppy.guppy_types import GuppyType + + try: + ty = getattr(node, "type") + return ty if isinstance(ty, GuppyType) else None + except AttributeError: + return None + + +def get_type(node: AstNode) -> "GuppyType": + """Retrieve a type annotation from an AST node. + + Fails if the node is not annotated. + """ + ty = get_type_opt(node) + assert ty is not None + return ty + + +def has_empty_body(func_ast: ast.FunctionDef) -> bool: """Returns `True` if the body of a function definition is empty. This is the case if the body only contains a single `pass` statement or an ellipsis diff --git a/guppy/error.py b/guppy/error.py index 428c2238..81dfec17 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -1,11 +1,19 @@ +import ast +import functools +import sys +import textwrap from dataclasses import dataclass, field -from typing import Optional, Any, Sequence +from typing import Optional, Any, Sequence, Callable, TypeVar, cast -from guppy.ast_util import AstNode -from guppy.guppy_types import GuppyType +from guppy.ast_util import AstNode, get_line_offset, get_file, get_source +from guppy.guppy_types import GuppyType, FunctionType from guppy.hugr.hugr import OutPortV, Node +# Whether the interpreter should exit when a Guppy error occurs +EXIT_ON_ERROR: bool = True + + @dataclass(frozen=True) class SourceLoc: """A source location associated with an AST node. @@ -14,13 +22,16 @@ class SourceLoc: inside the file. """ + file: str line: int col: int ast_node: Optional[AstNode] @staticmethod - def from_ast(node: AstNode, line_offset: int) -> "SourceLoc": - return SourceLoc(line_offset + node.lineno, node.col_offset, node) + def from_ast(node: AstNode) -> "SourceLoc": + file, line_offset = get_file(node), get_line_offset(node) + assert file is not None and line_offset is not None + return SourceLoc(file, line_offset + node.lineno - 1, node.col_offset, node) def __str__(self) -> str: return f"{self.line}:{self.col}" @@ -43,14 +54,14 @@ class GuppyError(Exception): # The message can also refer to AST locations using format placeholders `{0}`, `{1}` locs_in_msg: Sequence[Optional[AstNode]] = field(default_factory=list) - def get_msg(self, line_offset: int) -> str: + def get_msg(self) -> str: """Returns the message associated with this error. A line offset is needed to translate AST locations mentioned in the message into source locations in the actual file.""" return self.raw_msg.format( *( - SourceLoc.from_ast(loc, line_offset) if loc is not None else "???" + SourceLoc.from_ast(loc) if loc is not None else "???" for loc in self.locs_in_msg ) ) @@ -88,3 +99,81 @@ def node(self) -> Node: @property def offset(self) -> int: raise InternalGuppyError("Tried to access undefined Port") + + +class UnknownFunctionType(FunctionType): + """Dummy function type for custom functions without an expressible type. + + Raises an `InternalGuppyError` if one tries to access one of its members. + """ + + def __init__(self) -> None: + pass + + @property + def args(self) -> Sequence[GuppyType]: + raise InternalGuppyError("Tried to access unknown function type") + + @property + def returns(self) -> GuppyType: + raise InternalGuppyError("Tried to access unknown function type") + + @property + def args_names(self) -> Optional[Sequence[str]]: + raise InternalGuppyError("Tried to access unknown function type") + + +def format_source_location( + loc: ast.AST, + num_lines: int = 3, + indent: int = 4, +) -> str: + """Creates a pretty banner to show source locations for errors.""" + source, line_offset = get_source(loc), get_line_offset(loc) + assert source is not None and line_offset is not None + source_lines = source.splitlines(keepends=True) + end_col_offset = loc.end_col_offset or len(source_lines[loc.lineno]) + s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() + s += "\n" + loc.col_offset * " " + (end_col_offset - loc.col_offset) * "^" + s = textwrap.dedent(s).splitlines() + # Add line numbers + line_numbers = [ + str(line_offset + loc.lineno - i) + ":" for i in range(num_lines, 0, -1) + ] + longest = max(len(ln) for ln in line_numbers) + prefixes = [ln + " " * (longest - len(ln) + indent) for ln in line_numbers] + res = "".join(prefix + line + "\n" for prefix, line in zip(prefixes, s[:-1])) + res += (longest + indent) * " " + s[-1] + return res + + +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + + +def pretty_errors(f: FuncT) -> FuncT: + """Decorator to print custom error banners when a `GuppyError` occurs.""" + + @functools.wraps(f) + def wrapped(*args: Any, **kwargs: Any) -> Any: + try: + return f(*args, **kwargs) + except GuppyError as err: + # Reraise if we're missing a location + if not err.location: + raise err + loc = err.location + file, line_offset = get_file(loc), get_line_offset(loc) + assert file is not None and line_offset is not None + line = line_offset + loc.lineno - 1 + print( + f"Guppy compilation failed. Error in file {file}:{line}\n\n" + f"{format_source_location(loc)}\n" + f"{err.__class__.__name__}: {err.get_msg()}", + file=sys.stderr, + ) + if EXIT_ON_ERROR: + sys.exit(1) + return None + + return cast(FuncT, wrapped) + diff --git a/guppy/guppy_types.py b/guppy/guppy_types.py index 0658781e..95c86f1c 100644 --- a/guppy/guppy_types.py +++ b/guppy/guppy_types.py @@ -7,7 +7,7 @@ from guppy.ast_util import AstNode, set_location_from if TYPE_CHECKING: - from guppy.compiler_base import Globals + from guppy.checker.core import Globals class GuppyType(ABC): @@ -20,7 +20,7 @@ class GuppyType(ABC): @staticmethod @abstractmethod - def build(*args: "GuppyType", node: Union[ast.Name, ast.Subscript]) -> "GuppyType": + def build(*args: "GuppyType", node: AstNode) -> "GuppyType": pass @property @@ -33,23 +33,10 @@ def to_hugr(self) -> tys.SimpleType: pass -@dataclass(frozen=True) -class TypeRow: - tys: Sequence[GuppyType] - - def __str__(self) -> str: - if len(self.tys) == 0: - return "None" - elif len(self.tys) == 1: - return str(self.tys[0]) - else: - return f"({', '.join(str(e) for e in self.tys)})" - - @dataclass(frozen=True) class FunctionType(GuppyType): args: Sequence[GuppyType] - returns: Sequence[GuppyType] + returns: GuppyType arg_names: Optional[Sequence[str]] = field( default=None, compare=False, # Argument names are not taken into account for type equality @@ -59,17 +46,21 @@ class FunctionType(GuppyType): linear = False def __str__(self) -> str: - return f"{TypeRow(self.args)} -> {TypeRow(self.returns)}" + if len(self.args) == 1: + [arg] = self.args + return f"{arg} -> {self.returns}" + else: + return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: + def build(*args: GuppyType, node: AstNode) -> GuppyType: # Function types cannot be constructed using `build`. The type parsing code # has a special case for function types. raise NotImplementedError() def to_hugr(self) -> tys.SimpleType: ins = [t.to_hugr() for t in self.args] - outs = [t.to_hugr() for t in self.returns] + outs = [t.to_hugr() for t in type_to_row(self.returns)] return tys.FunctionType(input=ins, output=outs, extension_reqs=[]) @@ -80,7 +71,7 @@ class TupleType(GuppyType): name: str = "tuple" @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: + def build(*args: GuppyType, node: AstNode) -> GuppyType: return TupleType(list(args)) def __str__(self) -> str: @@ -100,7 +91,7 @@ class SumType(GuppyType): element_types: Sequence[GuppyType] @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: + def build(*args: GuppyType, node: AstNode) -> GuppyType: # Sum types cannot be parsed and constructed using `build` since they cannot be # written by the user raise NotImplementedError() @@ -121,9 +112,54 @@ def to_hugr(self) -> tys.SimpleType: return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) +@dataclass(frozen=True) +class NoneType(GuppyType): + name: str = "None" + linear: bool = False + + @staticmethod + def build(*args: GuppyType, node: AstNode) -> GuppyType: + if len(args) > 0: + from guppy.error import GuppyError + + raise GuppyError("Type `None` is not generic", node) + return NoneType() + + def __str__(self) -> str: + return "None" + + def to_hugr(self) -> tys.SimpleType: + return tys.Tuple(inner=[]) + + +@dataclass(frozen=True) +class BoolType(SumType): + """The type of booleans.""" + + linear = False + name = "bool" + + def __init__(self) -> None: + # Hugr bools are encoded as Sum((), ()) + super().__init__([TupleType([]), TupleType([])]) + + @staticmethod + def build(*args: GuppyType, node: AstNode) -> GuppyType: + if len(args) > 0: + from guppy.error import GuppyError + + raise GuppyError("Type `bool` is not generic", node) + return BoolType() + + def __str__(self) -> str: + return "bool" + + def _lookup_type(node: AstNode, globals: "Globals") -> Optional[type[GuppyType]]: if isinstance(node, ast.Name) and node.id in globals.types: return globals.types[node.id] + if isinstance(node, ast.Constant) and node.value is None: + return NoneType if ( isinstance(node, ast.Constant) and isinstance(node.value, str) @@ -157,23 +193,42 @@ def type_from_ast(node: AstNode, globals: "Globals") -> GuppyType: if isinstance(func_args, ast.List): return FunctionType( [type_from_ast(a, globals) for a in func_args.elts], - type_row_from_ast(ret, globals).tys, + type_from_ast(ret, globals), ) from guppy.error import GuppyError raise GuppyError("Not a valid Guppy type", node) -def type_row_from_ast(node: ast.expr, globals: "Globals") -> TypeRow: +def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[GuppyType]: """Turns an AST expression into a Guppy type row. This is needed to interpret the return type annotation of functions. """ # The return type `-> None` is represented in the ast as `ast.Constant(value=None)` if isinstance(node, ast.Constant) and node.value is None: - return TypeRow([]) + return [] ty = type_from_ast(node, globals) if isinstance(ty, TupleType): - return TypeRow(ty.element_types) + return ty.element_types + else: + return [ty] + + +def row_to_type(row: Sequence[GuppyType]) -> GuppyType: + """Turns a row of types into a single type by packing into a tuple.""" + if len(row) == 0: + return NoneType() + elif len(row) == 1: + return row[0] else: - return TypeRow([ty]) + return TupleType(row) + + +def type_to_row(ty: GuppyType) -> Sequence[GuppyType]: + """Turns a type into a row of types by unpacking top-level tuples.""" + if isinstance(ty, NoneType): + return [] + if isinstance(ty, TupleType): + return ty.element_types + return [ty] diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 8baf350d..71632980 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Optional, Iterator, Tuple, Any +from typing import Optional, Iterator, Tuple, Any, Union from dataclasses import field, dataclass import guppy.hugr.ops as ops @@ -12,9 +12,9 @@ GuppyType, TupleType, FunctionType, - SumType, + SumType, type_to_row, row_to_type, ) -from guppy.hugr import val +from guppy.hugr import val, tys NodeIdx = int PortOffset = int @@ -483,7 +483,7 @@ def add_call( """Adds a `Call` node to the graph.""" assert isinstance(def_port.ty, FunctionType) return self.add_node( - ops.Call(), None, list(def_port.ty.returns), parent, args + [def_port] + ops.Call(), None, list(type_to_row(def_port.ty.returns)), parent, args + [def_port] ) def add_indirect_call( @@ -494,7 +494,7 @@ def add_indirect_call( return self.add_node( ops.CallIndirect(), None, - list(fun_port.ty.returns), + list(type_to_row(fun_port.ty.returns)), parent, [fun_port] + args, ) @@ -628,7 +628,7 @@ def remove_dummy_nodes(self) -> "Hugr": for n in list(self.nodes()): if isinstance(n, VNode) and isinstance(n.op, ops.DummyOp): name = n.op.name - fun_ty = FunctionType(list(n.in_port_types), list(n.out_port_types)) + fun_ty = FunctionType(list(n.in_port_types), row_to_type(n.out_port_types)) if name in used_names: used_names[name] += 1 name = f"{name}${used_names[name]}" From d48a14935c5dece30060dddd52bba0cc76644398 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:58:42 +0000 Subject: [PATCH 05/77] Add statement compiler --- guppy/compiler/stmt_compiler.py | 91 +++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 guppy/compiler/stmt_compiler.py diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py new file mode 100644 index 00000000..9ed45c65 --- /dev/null +++ b/guppy/compiler/stmt_compiler.py @@ -0,0 +1,91 @@ +import ast +from typing import Sequence + +from guppy.ast_util import AstVisitor, get_type +from guppy.checker.cfg_checker import CheckedBB +from guppy.compiler.core import CompilerBase, DFContainer, CompiledGlobals, \ + PortVariable, return_var +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.error import InternalGuppyError +from guppy.guppy_types import GuppyType, NoneType, TupleType +from guppy.hugr.hugr import Hugr, OutPortV +from guppy.nodes import CheckedNestedFunctionDef + + +class StmtCompiler(CompilerBase, AstVisitor[None]): + """A compiler for Guppy statements to Hugr""" + + expr_compiler: ExprCompiler + + bb: CheckedBB + dfg: DFContainer + + def __init__(self, graph: Hugr, globals: CompiledGlobals): + super().__init__(graph, globals) + self.expr_compiler = ExprCompiler(graph, globals) + + def compile_stmts( + self, + stmts: Sequence[ast.stmt], + bb: CheckedBB, + dfg: DFContainer, + ) -> DFContainer: + """Compiles a list of basic statements into a dataflow node. + + Note that the `dfg` is mutated in-place. After compilation, the DFG will also + contain all variables that are assigned in the given list of statements. + """ + self.bb = bb + self.dfg = dfg + for s in stmts: + self.visit(s) + return self.dfg + + def _unpack_assign(self, lhs: ast.expr, port: OutPortV, node: ast.stmt) -> None: + """Updates the local DFG with assignments.""" + if isinstance(lhs, ast.Name): + x = lhs.id + self.dfg[x] = PortVariable(x, port, node) + elif isinstance(lhs, ast.Tuple): + unpack = self.graph.add_unpack_tuple(port, self.dfg.node) + for i, pat in enumerate(lhs.elts): + self._unpack_assign(pat, unpack.out_port(i), node) + else: + raise InternalGuppyError("Invalid assign pattern in compiler") + + def visit_Assign(self, node: ast.Assign) -> None: + [target] = node.targets + port = self.expr_compiler.compile(node.value, self.dfg) + self._unpack_assign(target, port, node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + assert node.value is not None + port = self.expr_compiler.compile(node.value, self.dfg) + self._unpack_assign(node.target, port, node) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Expr(self, node: ast.Expr) -> None: + self.expr_compiler.compile_row(node.value, self.dfg) + + def visit_Return(self, node: ast.Return) -> None: + # We turn returns into assignments of dummy variables, i.e. the statement + # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`. + if node.value is not None: + port = self.expr_compiler.compile(node.value, self.dfg) + if isinstance(port.ty, TupleType): + unpack = self.graph.add_unpack_tuple(port, self.dfg.node) + row = [unpack.out_port(i) for i in range(len(port.ty.element_types))] + else: + row = [port] + for i, port in enumerate(row): + name = return_var(i) + self.dfg[name] = PortVariable(name, port, node) + + def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: + from guppy.compiler.func_compiler import compile_local_func_def + + self.dfg[node.name] = compile_local_func_def(node, self.dfg, self.graph, self.globals) + + From a5d51e130f26bc6c62239019978ac61dea7b5bb7 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:59:07 +0000 Subject: [PATCH 06/77] Moves custom AST nodes to separate file --- guppy/nodes.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 guppy/nodes.py diff --git a/guppy/nodes.py b/guppy/nodes.py new file mode 100644 index 00000000..a84f9de0 --- /dev/null +++ b/guppy/nodes.py @@ -0,0 +1,73 @@ +"""Custom AST nodes used by Guppy""" + +import ast +from typing import TYPE_CHECKING, Any, Mapping + +from guppy.guppy_types import FunctionType + +if TYPE_CHECKING: + from guppy.cfg.cfg import CFG + from guppy.checker.core import Variable, CallableVariable + from guppy.checker.cfg_checker import CheckedCFG + + +class LocalName(ast.expr): + id: str + + _fields = ( + 'id', + ) + + +class GlobalName(ast.expr): + id: str + value: "Variable" + + _fields = ( + 'id', + 'value', + ) + + +class LocalCall(ast.expr): + func: ast.expr + args: list[ast.expr] + + _fields = ( + 'func', + 'args', + ) + + +class GlobalCall(ast.expr): + func: "CallableVariable" + args: list[ast.expr] + + # Later: Inferred type args + + _fields = ( + 'func', + 'args', + ) + + +class NestedFunctionDef(ast.FunctionDef): + cfg: "CFG" + ty: FunctionType + + def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.cfg = cfg + self.ty = ty + + +class CheckedNestedFunctionDef(ast.FunctionDef): + cfg: "CheckedCFG" + ty: FunctionType + captured: Mapping[str, "Variable"] + + def __init__(self, cfg: "CheckedCFG", ty: FunctionType, captured: Mapping[str, "Variable"], *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.cfg = cfg + self.ty = ty + self.captured = captured From 3edaa00d4000025570d9a3d912d0d8f21e6a953f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 15:59:36 +0000 Subject: [PATCH 07/77] Add custom funcs, replacing extensions --- guppy/custom.py | 231 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 guppy/custom.py diff --git a/guppy/custom.py b/guppy/custom.py new file mode 100644 index 00000000..d267ae07 --- /dev/null +++ b/guppy/custom.py @@ -0,0 +1,231 @@ +import ast +from abc import ABC, abstractmethod +from typing import Optional + +from guppy.ast_util import AstNode, get_type, with_type, with_loc +from guppy.checker.core import CallableVariable, Context, Globals +from guppy.checker.expr_checker import check_call, synthesize_call +from guppy.checker.func_checker import check_signature +from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.error import GuppyError, InternalGuppyError, UnknownFunctionType, \ + GuppyTypeError +from guppy.guppy_types import GuppyType, FunctionType, type_to_row, TupleType +from guppy.hugr import ops +from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode +from guppy.nodes import GlobalCall + + +class CustomFunction(CompiledFunction): + """A function whose type checking and compilation behaviour can be customised.""" + + defined_at: Optional[ast.FunctionDef] + + # Whether the function may be used as a higher-order value. This is only possible + # if a static type for the function is provided. + higher_order_value: bool + + call_checker: "CustomCallChecker" + call_compiler: "CustomCallCompiler" + + _ty: Optional[FunctionType] = None + _defined: dict[Node, DFContainingVNode] = {} + + def __init__(self, name: str, defined_at: Optional[ast.FunctionDef], compiler: "CustomCallCompiler", checker: "CustomCallChecker", higher_order_value: bool = True, ty: Optional[FunctionType] = None): + self.name = name + self.defined_at = defined_at + self.higher_order_value = higher_order_value + self.call_compiler = compiler + self.call_checker = checker + self.used = None + self._ty = ty + self._defined = {} + + @property # type: ignore + def ty(self) -> FunctionType: + if self._ty is None: + return UnknownFunctionType() + return self._ty + + @ty.setter + def ty(self, ty: FunctionType) -> None: + self._ty = ty + + def check_type(self, globals: Globals) -> None: + """Checks the type annotation on the signature declaration if provided.""" + if self._ty is not None: + return + + if self.defined_at is None: + if self.higher_order_value: + raise GuppyError( + f"Type signature for function `{self.name}` is required. " + "Alternatively, try passing `higher_order_value=False` on " + "definition." + ) + return + + try: + self._ty = check_signature(self.defined_at, globals) + except GuppyError as err: + # We can ignore the error if a custom call checker is provided and the + # function may not be used as a higher-order value + if self.call_checker is None or self.higher_order_value: + raise err + + def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> ast.expr: + self.call_checker._setup(ctx, node, self) + return with_type(ty, with_loc(node, self.call_checker.check(args, ty))) + + def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]: + self.call_checker._setup(ctx, node, self) + new_node, ty = self.call_checker.synthesize(args) + return with_type(ty, with_loc(node, new_node)), ty + + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: + self.call_compiler._setup(dfg, graph, globals, node) + return self.call_compiler.compile(args) + + def load(self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + """Loads the custom function as a value into a local dataflow graph. + + This will place a `FunctionDef` node into the Hugr module if one for this + function doesn't already exist and loads it into the DFG. This operation will + fail if no function type has been specified. + """ + if self._ty is None: + raise GuppyError( + "This function does not support usage in a higher-order context", + node, + ) + + # Find the module node by walking up the hierarchy + module: Node = dfg.node + while not isinstance(module.op, ops.Module): + if module.parent is None: + raise InternalGuppyError( + "Encountered node that is not contained in a module." + ) + module = module.parent + + # If the function has not yet been loaded in this module, we first have to + # define it. We create a `FunctionDef` that takes some inputs, compiles a call + # to the function, and returns the results. + if module not in self._defined: + def_node = graph.add_def(self.ty, module, self.name) + inp = graph.add_input(list(self.ty.args), parent=def_node) + returns = self.compile_call( + [inp.out_port(i) for i in range(len(self.ty.args))], + DFContainer(def_node, {}), + graph, + globals, + node, + ) + graph.add_output(returns, parent=def_node) + self._defined[module] = def_node + + # Finally, load the function into the local DFG + return graph.add_load_constant( + self._defined[module].out_port(0), dfg.node + ).out_port(0) + + +class CustomCallChecker(ABC): + """Protocol for custom function call type checkers.""" + + ctx: Context + node: AstNode + func: CustomFunction + + def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: + self.ctx = ctx + self.node = node + self.func = func + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + """Checks the return value against a given type. + + Returns a (possibly) transformed and annotated AST node for the call. + """ + args, actual = self.synthesize(args) + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `{actual}`", self.node + ) + + @abstractmethod + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + """Synthesizes a type for the return value of a call. + + Also returns a (possibly) transformed and annotated argument list. + """ + + +class CustomCallCompiler(ABC): + """Protocol for custom function call compilers.""" + + dfg: DFContainer + graph: Hugr + globals: CompiledGlobals + node: AstNode + + def _setup(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> None: + self.dfg = dfg + self.graph = graph + self.globals = globals + self.node = node + + @abstractmethod + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + """Compiles a custom function call and returns the resulting ports.""" + + +class DefaultCallChecker(CustomCallChecker): + """Checks function calls by comparing to a type signature.""" + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + # Use default implementation from the expression checker + args = check_call(self.func.ty, args, ty, self.node, self.ctx) + return GlobalCall(func=self.func, args=args) + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + # Use default implementation from the expression checker + args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) + return GlobalCall(func=self.func, args=args), ty + + +class DefaultCallCompiler(CustomCallCompiler): + """Call compiler that invokes the regular expression compiler.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + assert isinstance(self.node, ast.expr) + ret = ExprCompiler(self.graph, self.globals).compile(self.node, self.dfg) + if isinstance(ret.ty, TupleType): + unpack = self.graph.add_unpack_tuple(ret, self.dfg.node) + return [unpack.out_port(i) for i in range(len(ret.ty.element_types))] + return [ret] + + +class OpCompiler(CustomCallCompiler): + op: ops.OpType + + def __init__(self, op: ops.OpType) -> None: + self.op = op + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) + return_ty = get_type(self.node) + assert return_ty is not None + return [node.add_out_port(ty) for ty in type_to_row(return_ty)] + + +class NoopCompiler(CustomCallCompiler): + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + return args From de2508fe73582e8628cbc00d01afa7f2e0888f0e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:00:03 +0000 Subject: [PATCH 08/77] Factor out decorator from module code --- guppy/declared.py | 49 +++++++++++ guppy/decorator.py | 161 ++++++++++++++++++++++++++++++++++++ guppy/module.py | 201 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 411 insertions(+) create mode 100644 guppy/declared.py create mode 100644 guppy/decorator.py create mode 100644 guppy/module.py diff --git a/guppy/declared.py b/guppy/declared.py new file mode 100644 index 00000000..167f9abd --- /dev/null +++ b/guppy/declared.py @@ -0,0 +1,49 @@ +import ast +from dataclasses import dataclass +from typing import Optional + +from guppy.ast_util import AstNode, has_empty_body +from guppy.checker.core import Globals, Context +from guppy.checker.expr_checker import check_call, synthesize_call +from guppy.checker.func_checker import check_signature +from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals +from guppy.error import GuppyError +from guppy.guppy_types import type_to_row, GuppyType +from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV +from guppy.nodes import GlobalCall + + +@dataclass +class DeclaredFunction(CompiledFunction): + """A user-declared function that compiles to a Hugr function declaration.""" + + node: Optional[VNode] = None + + @staticmethod + def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DeclaredFunction": + ty = check_signature(func_def, globals) + if not has_empty_body(func_def): + raise GuppyError("Body of function declaration must be empty", func_def.body[0]) + return DeclaredFunction(name, ty, func_def, None) + + def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall: + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) + + def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]: + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty + + def add_to_graph(self, graph: Hugr, parent: Node) -> None: + self.node = graph.add_declare(self.ty, parent, self.name) + + def load(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + assert self.node is not None + return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) + + def compile_call(self, args: list[OutPortV], dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> list[OutPortV]: + assert self.node is not None + call = graph.add_call(self.node.out_port(0), args, dfg.node) + return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/decorator.py b/guppy/decorator.py new file mode 100644 index 00000000..b61c06ac --- /dev/null +++ b/guppy/decorator.py @@ -0,0 +1,161 @@ +import ast +import functools +from dataclasses import dataclass +from typing import Optional, Union, Callable, Any + +from guppy.ast_util import AstNode, has_empty_body +from guppy.checker.func_checker import check_signature +from guppy.custom import CustomFunction, OpCompiler, DefaultCallChecker, \ + CustomCallCompiler, CustomCallChecker, DefaultCallCompiler +from guppy.error import GuppyError, pretty_errors +from guppy.guppy_types import GuppyType +from guppy.hugr import tys, ops +from guppy.hugr.hugr import Hugr +from guppy.module import GuppyModule, PyFunc, parse_py_func + +FuncDecorator = Callable[[PyFunc], PyFunc] +CustomFuncDecorator = Callable[[PyFunc], CustomFunction] +ClassDecorator = Callable[[type], type] + + +class _Guppy: + """Class for the `@guppy` decorator.""" + + # The current module + _module: Optional[GuppyModule] + + def __init__(self) -> None: + self._module = None + + def set_module(self, module: GuppyModule) -> None: + self._module = module + + @pretty_errors + def __call__(self, arg: Union[PyFunc, GuppyModule]) -> Union[Optional[Hugr], FuncDecorator]: + """Decorator to annotate Python functions as Guppy code. + + Optionally, the `GuppyModule` in which the function should be placed can be passed + to the decorator. + """ + if isinstance(arg, GuppyModule): + def dec(f: Callable[..., Any]) -> Callable[..., Any]: + assert isinstance(arg, GuppyModule) + arg.register_func_def(f) + + @functools.wraps(f) + def dummy(*args: Any, **kwargs: Any) -> Any: + raise GuppyError( + "Guppy functions can only be called in a Guppy context" + ) + + return dummy + + return dec + else: + module = self._module or GuppyModule("module") + module.register_func_def(arg) + return module.compile() + + @pretty_errors + def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator: + """Decorator to add new instance functions to a type.""" + module._instance_func_buffer = {} + + def dec(c: type) -> type: + module._register_buffered_instance_funcs(ty) + return c + + return dec + + @pretty_errors + def type(self, module: GuppyModule, hugr_ty: tys.SimpleType, name: str = "", linear: bool = False) -> ClassDecorator: + """Decorator to annotate a class definitions as Guppy types. + + Requires the static Hugr translation of the type. Additionally, the type can be + marked as linear. All `@guppy` annotated functions on the class are turned into + instance functions. + """ + module._instance_func_buffer = {} + + def dec(c: type) -> type: + _name = name or c.__name__ + + @dataclass(frozen=True) + class NewType(GuppyType): + name = _name + + @staticmethod + def build(*args: GuppyType, node: AstNode) -> "GuppyType": + # At the moment, custom types don't support type arguments. + if len(args) > 0: + raise GuppyError( + f"Type `{_name}` does not accept type parameters.", node + ) + return NewType() + + @property + def linear(self) -> bool: + return linear + + def to_hugr(self) -> tys.SimpleType: + return hugr_ty + + def __str__(self) -> str: + return _name + + NewType.__name__ = name + NewType.__qualname__ = _name + module.register_type(_name, NewType) + module._register_buffered_instance_funcs(NewType) + setattr(c, "_guppy_type", NewType) + return c + + return dec + + @pretty_errors + def custom(self, module: GuppyModule, compiler: Optional[CustomCallCompiler] = None, checker: Optional[CustomCallChecker] = None, higher_order_value: bool = True, name: str = "") -> CustomFuncDecorator: + """Decorator to add custom typing or compilation behaviour to function decls. + + Optionally, usage of the function as a higher-order value can be disabled. In + that case, the function signature can be omitted if a custom call compiler is + provided. + """ + + def dec(f: PyFunc) -> CustomFunction: + func_ast = parse_py_func(f) + if not has_empty_body(func_ast): + raise GuppyError( + "Body of custom function declaration must be empty", + func_ast.body[0] + ) + call_checker = checker or DefaultCallChecker() + func = CustomFunction(name or func_ast.name, func_ast, compiler or DefaultCallCompiler(), call_checker, higher_order_value) + call_checker.func = func + module.register_custom_func(func) + return func + + return dec + + def hugr_op(self, module: GuppyModule, op: ops.OpType, checker: Optional[CustomCallChecker] = None, higher_order_value: bool = True, name: str = "") -> CustomFuncDecorator: + """Decorator to annotate function declarations as HUGR ops.""" + return self.custom(module, OpCompiler(op), checker, higher_order_value, name) + + def declare(self, module: GuppyModule) -> FuncDecorator: + """Decorator to declare functions""" + def dec(f: Callable[..., Any]) -> Callable[..., Any]: + module.register_func_decl(f) + + @functools.wraps(f) + def dummy(*args: Any, **kwargs: Any) -> Any: + raise GuppyError( + "Guppy functions can only be called in a Guppy context" + ) + + return dummy + + return dec + + + + +guppy = _Guppy() diff --git a/guppy/module.py b/guppy/module.py new file mode 100644 index 00000000..79ddeecf --- /dev/null +++ b/guppy/module.py @@ -0,0 +1,201 @@ +import ast +import inspect +import textwrap +from types import ModuleType + +from typing import Callable, Any, Optional, Union + +from guppy.ast_util import annotate_location, AstNode +from guppy.checker.core import Globals, qualified_name +from guppy.checker.func_checker import DefinedFunction, check_global_func_def +from guppy.compiler.core import CompiledGlobals, CompiledFunction, DFContainer +from guppy.compiler.func_compiler import compile_global_func_def, CompiledFunctionDef +from guppy.custom import CustomFunction +from guppy.declared import DeclaredFunction +from guppy.error import GuppyError, pretty_errors +from guppy.guppy_types import GuppyType, type_to_row +from guppy.hugr.hugr import Hugr, Node, VNode, OutPortV + +PyFunc = Callable[..., Any] + + +class GuppyModule: + """A Guppy module that may contain function and type definitions.""" + + name: str + + # Whether the module has already been compiled + _compiled: bool + + # Globals from imported modules + _imported_globals: Globals + _imported_compiled_globals: CompiledGlobals + + # Globals for functions and types defined in this module. Only gets populated during + # compilation + _globals: Globals + _compiled_globals: CompiledGlobals + + # Mappings of functions defined in this module + _func_defs: dict[str, ast.FunctionDef] + _func_decls: dict[str, ast.FunctionDef] + _custom_funcs: dict[str, CustomFunction] + + # When `_instance_buffer` is not `None`, then all registered functions will be + # buffered in this list. They only get properly registered, once + # `_register_buffered_instance_funcs` is called. This way, we can associate + _instance_func_buffer: Optional[dict[str, Union[PyFunc, CustomFunction]]] + + def __init__(self, name: str, import_builtins: bool = True): + self.name = name + self._globals = Globals({}, {}) + self._compiled_globals = {} + self._imported_globals = Globals.default() + self._imported_compiled_globals = {} + self._func_defs = {} + self._func_decls = {} + self._custom_funcs = {} + self._compiled = False + self._instance_func_buffer = None + + # Import builtin module + if import_builtins: + import guppy.prelude.builtins as builtins + self.load(builtins) + + def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: + """Imports another Guppy module.""" + self._check_not_yet_compiled() + if isinstance(m, GuppyModule): + # Compile module if it isn't compiled yet + if not m.compiled: + m.compile() + + # For now, we can only import custom functions + if any(not isinstance(v, CustomFunction) for v in m._compiled_globals.values()): + raise GuppyError( + "Importing modules with defined functions is not supported yet" + ) + + self._imported_globals |= m._globals + self._imported_compiled_globals |= m._compiled_globals + else: + for val in m.__dict__.values(): + if isinstance(val, GuppyModule): + self.load(val) + + def register_func_def(self, f: PyFunc, instance: Optional[type[GuppyType]] = None) -> None: + """Registers a Python function definition as belonging to this Guppy module.""" + self._check_not_yet_compiled() + func_ast = parse_py_func(f) + if self._instance_func_buffer is not None: + self._instance_func_buffer[func_ast.name] = f + else: + name = qualified_name(instance, func_ast.name) if instance else func_ast.name + self._check_name_available(name, func_ast) + self._func_defs[name] = func_ast + + def register_func_decl(self, f: PyFunc) -> None: + """Registers a Python function declaration as belonging to this Guppy module.""" + self._check_not_yet_compiled() + func_ast = parse_py_func(f) + self._check_name_available(func_ast.name, func_ast) + self._func_decls[func_ast.name] = func_ast + + def register_custom_func(self, func: CustomFunction, instance: Optional[type[GuppyType]] = None) -> None: + """Registers a custom function as belonging to this Guppy module.""" + self._check_not_yet_compiled() + if self._instance_func_buffer is not None: + self._instance_func_buffer[func.name] = func + else: + if instance: + func.name = qualified_name(instance, func.name) + self._check_name_available(func.name, func.defined_at) + self._custom_funcs[func.name] = func + + def register_type(self, name: str, ty: type[GuppyType]) -> None: + """Registers an existing Guppy type as belonging to this Guppy module.""" + self._globals.types[name] = ty + + def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: + assert self._instance_func_buffer is not None + buffer = self._instance_func_buffer + self._instance_func_buffer = None + for name, f in buffer.items(): + if isinstance(f, CustomFunction): + self.register_custom_func(f, instance) + else: + self.register_func_def(f, instance) + + @property + def compiled(self) -> bool: + return self._compiled + + @pretty_errors + def compile(self) -> Optional[Hugr]: + """Compiles the module and returns the final Hugr.""" + if self.compiled: + raise GuppyError("Module has already been compiled") + + # Prepare globals for type checking + for func in self._custom_funcs.values(): + func.check_type(self._imported_globals | self._globals) + defined_funcs = { + x: DefinedFunction.from_ast(f, x, self._imported_globals | self._globals) + for x, f in self._func_defs.items() + } + declared_funcs = { + x: DeclaredFunction.from_ast(f, x, self._imported_globals | self._globals) + for x, f in self._func_decls.items() + } + self._globals.values.update(self._custom_funcs) + self._globals.values.update(declared_funcs) + self._globals.values.update(defined_funcs) + + # Type check function definitions + checked = {x: check_global_func_def(f, self._imported_globals | self._globals) for x, f in defined_funcs.items()} + + # Add declared functions to the graph + graph = Hugr(self.name) + module_node = graph.set_root_name(self.name) + for f in declared_funcs.values(): + f.add_to_graph(graph, module_node) + + # Prepare `FunctionDef` nodes for all function definitions + def_nodes = {x: graph.add_def(f.ty, module_node, x) for x, f in checked.items()} + self._compiled_globals |= self._custom_funcs | declared_funcs | { + x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) + for x, f in checked.items() + } + + # Compile function definitions to Hugr + for x, f in checked.items(): + compile_global_func_def(f, def_nodes[x], graph, self._imported_compiled_globals | self._compiled_globals) + + self._compiled = True + return graph + + def _check_not_yet_compiled(self) -> None: + if self._compiled: + raise GuppyError(f"The module `{self.name}` has already been compiled") + + def _check_name_available(self, name: str, node: Optional[AstNode]) -> None: + if name in self._func_defs or name in self._custom_funcs: + raise GuppyError( + f"Module `{self.name}` already contains a function named `{name}`", + node, + ) + + +def parse_py_func(f: PyFunc) -> ast.FunctionDef: + source_lines, line_offset = inspect.getsourcelines(f) + source = "".join(source_lines) # Lines already have trailing \n's + source = textwrap.dedent(source) + func_ast = ast.parse(source).body[0] + file = inspect.getsourcefile(f) + if file is None: + raise GuppyError("Couldn't determine source file for function") + annotate_location(func_ast, source, file, line_offset) + if not isinstance(func_ast, ast.FunctionDef): + raise GuppyError("Expected a function definition", func_ast) + return func_ast From f9033fde946bfb2dd12d2b2c96f7c18ea4706f5c Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:00:39 +0000 Subject: [PATCH 09/77] Reimplement prelude --- guppy/prelude/_internal.py | 263 +++++++++++++++++++++++++ guppy/prelude/builtins.py | 385 +++++++++++++++++++++++++++++++++++++ guppy/prelude/quantum.py | 37 ++-- 3 files changed, 665 insertions(+), 20 deletions(-) create mode 100644 guppy/prelude/_internal.py create mode 100644 guppy/prelude/builtins.py diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py new file mode 100644 index 00000000..7c410f7c --- /dev/null +++ b/guppy/prelude/_internal.py @@ -0,0 +1,263 @@ +import ast +from typing import Optional, Literal + +from pydantic import BaseModel + +from guppy.ast_util import with_type, AstNode, with_loc, get_type +from guppy.checker.core import Context, CallableVariable +from guppy.checker.expr_checker import ExprSynthesizer, check_num_args +from guppy.custom import CustomCallChecker, DefaultCallChecker, CustomFunction, \ + CustomCallCompiler +from guppy.error import GuppyTypeError +from guppy.guppy_types import GuppyType, type_to_row, FunctionType, BoolType +from guppy.hugr import ops, tys, val +from guppy.hugr.hugr import OutPortV +from guppy.nodes import GlobalCall + + +INT_WIDTH = 6 # 2^6 = 64 bit + + +hugr_int_type = tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.BoundedNatArg(n=INT_WIDTH)], + bound=tys.TypeBound.Eq, +) + + +hugr_float_type = tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, +) + + +class ConstIntS(BaseModel): + """Hugr representation of signed integers in the arithmetic extension.""" + + c: Literal["ConstIntS"] = "ConstIntS" + log_width: int + value: int + + +class ConstF64(BaseModel): + """Hugr representation of floats in the arithmetic extension.""" + + c: Literal["ConstF64"] = "ConstF64" + value: float + + +def bool_value(b: bool) -> val.Value: + """Returns the Hugr representation of a boolean value.""" + return val.Sum(tag=int(b), value=val.Tuple(vs=[])) + + +def int_value(i: int) -> val.Value: + """Returns the Hugr representation of an integer value.""" + return val.Prim(val=val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),))) + + +def float_value(f: float) -> val.Value: + """Returns the Hugr representation of a float value.""" + return val.Prim(val=val.ExtensionVal(c=(ConstF64(value=f),))) + + +def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops.OpType: + """Utility method to create Hugr logic ops.""" + return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) + + +def int_op(op_name: str, ext: str = "arithmetic.int", num_params: int = 1) -> ops.OpType: + """Utility method to create Hugr integer arithmetic ops.""" + return ops.CustomOp( + extension=ext, + op_name=op_name, + args=num_params * [tys.BoundedNatArg(n=INT_WIDTH)], + ) + + +def float_op(op_name: str, ext: str = "arithmetic.float") -> ops.OpType: + """Utility method to create Hugr integer arithmetic ops.""" + return ops.CustomOp(extension=ext, op_name=op_name, args=[]) + + +class CoercingChecker(DefaultCallChecker): + """Function call type checker thag automatically coerces arguments to float.""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + from .builtins import Int + + for i in range(len(args)): + args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) + if isinstance(ty, self.ctx.globals.types["int"]): + call = with_loc(self.node, GlobalCall(func=Int.__float__, args=[args[i]])) + args[i] = with_type(self.ctx.globals.types["float"].build(node=self.node), call) + return super().synthesize(args) + + +class ReversingChecker(CustomCallChecker): + """Call checker that reverses the arguments after checking.""" + + base_checker: CustomCallChecker + + def __init__(self, base_checker: CustomCallChecker): + self.base_checker = base_checker + + def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: + super()._setup(ctx, node, func) + self.base_checker._setup(ctx, node, func) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + expr = self.base_checker.check(args, ty) + if isinstance(expr, GlobalCall): + expr.args = list(reversed(args)) + return expr + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + expr, ty = self.base_checker.synthesize(args) + if isinstance(expr, GlobalCall): + expr.args = list(reversed(args)) + return expr, ty + + +class NotImplementedCompiler(CustomCallCompiler): + """Compiler for functions that are not yet implemented.""" + + name: str + + def __init__(self, name: str) -> None: + self.name = name + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + node = self.graph.add_node(ops.DummyOp(name=self.name), inputs=args) + return_ty = get_type(self.node) + assert return_ty is not None + return [node.add_out_port(ty) for ty in type_to_row(return_ty)] + + +class DunderChecker(CustomCallChecker): + """Call checker for builtin functions that call out to dunder instance methods""" + + dunder_name: str + num_args: int + + def __init__(self, dunder_name: str, num_args: int = 1): + assert num_args > 0 + self.dunder_name = dunder_name + self.num_args = num_args + + def _get_func(self, args: list[ast.expr]) -> tuple[list[ast.expr], CallableVariable]: + check_num_args(self.num_args, len(args), self.node) + fst, *rest = args + fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) + func = self.ctx.globals.get_instance_func(ty, self.dunder_name) + if func is None: + raise GuppyTypeError( + f"Builtin function `{self.func.name}` is not defined for argument of " + f"type `{ty}`", + self.node.args[0] if isinstance(self.node, ast.Call) else self.node, + ) + return [fst, *rest], func + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + args, func = self._get_func(args) + return func.synthesize_call(args, self.node, self.ctx) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + args, func = self._get_func(args) + return func.check_call(args, ty, self.node, self.ctx) + + +class CallableChecker(CustomCallChecker): + """Call checker for the builtin `callable` function""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + check_num_args(1, len(args), self.node) + [arg] = args + arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) + is_callable = ( + isinstance(ty, FunctionType) + or self.ctx.globals.get_instance_func(ty, "__call__") is not None + ) + const = with_loc(self.node, ast.Constant(value=is_callable)) + return const, BoolType() + + +class IntTruedivCompiler(CustomCallCompiler): + """Compiler for the `int.__truediv__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Int, Float + + # Compile `truediv` using float arithmetic + [left, right] = args + [left] = Int.__float__.compile_call([left], self.dfg, self.graph, self.globals, self.node) + [right] = Int.__float__.compile_call([right], self.dfg, self.graph, self.globals, self.node) + return Float.__truediv__.compile_call([left, right], self.dfg, self.graph, self.globals, self.node) + + +class FloatBoolCompiler(CustomCallCompiler): + """Compiler for the `float.__bool__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: bool(x) = (x != 0.0) + zero_const = self.graph.add_constant(float_value(0.0), get_type(self.node), self.dfg.node) + zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) + return Float.__ne__.compile_call( + [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node + ) + + +class FloatFloordivCompiler(CustomCallCompiler): + """Compiler for the `float.__floordiv__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: floordiv(x, y) = floor(truediv(x, y)) + [div] = Float.__truediv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [floor] = Float.__floor__.compile_call( + [div], self.dfg, self.graph, self.globals, self.node + ) + return [floor] + + +class FloatModCompiler(CustomCallCompiler): + """Compiler for the `float.__mod__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: mod(x, y) = x - (x // y) * y + [div] = Float.__floordiv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [mul] = Float.__mul__.compile_call( + [div, args[1]], self.dfg, self.graph, self.globals, self.node + ) + [sub] = Float.__sub__.compile_call( + [args[0], mul], self.dfg, self.graph, self.globals, self.node + ) + return [sub] + + +class FloatDivmodCompiler(CustomCallCompiler): + """Compiler for the `__divmod__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: divmod(x, y) = (div(x, y), mod(x, y)) + [div] = Float.__truediv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [mod] = Float.__mod__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] \ No newline at end of file diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py new file mode 100644 index 00000000..c684aac6 --- /dev/null +++ b/guppy/prelude/builtins.py @@ -0,0 +1,385 @@ +"""Guppy module for builtin types and operations.""" + +# mypy: disable-error-code="empty-body, misc, override, no-untyped-def" + +from guppy.custom import NoopCompiler, DefaultCallChecker +from guppy.decorator import guppy +from guppy.guppy_types import BoolType +from guppy.hugr import tys +from guppy.module import GuppyModule +from guppy.prelude._internal import logic_op, int_op, hugr_int_type, hugr_float_type, \ + float_op, CoercingChecker, ReversingChecker, IntTruedivCompiler, FloatBoolCompiler, \ + FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, \ + NotImplementedCompiler, DunderChecker, CallableChecker + + +builtins = GuppyModule("builtins", import_builtins=False) + + +@guppy.extend_type(builtins, BoolType) +class Bool: + + @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) + def __and__(self: bool, other: bool) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __bool__(self: bool) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ifrombool")) + def __int__(self: bool) -> int: + ... + + @guppy.hugr_op(builtins, logic_op("Or", [tys.BoundedNatArg(n=2)])) + def __or__(self: bool, other: bool) -> bool: + ... + + +@guppy.type(builtins, hugr_int_type, name="int") +class Int: + + @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) + def __abs__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iadd")) + def __add__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iand")) + def __and__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("itobool")) + def __bool__(self: int) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __ceil__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2)) + def __divmod__(self: int, other: int) -> tuple[int, int]: + ... + + @guppy.hugr_op(builtins, int_op("ieq")) + def __eq__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("convert_s", "arithmetic.conversions")) + def __float__(self: int) -> float: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __floor__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2)) + def __floordiv__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ige_s")) + def __ge__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("igt_s")) + def __gt__(self: int, other: int) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __int__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("inot")) + def __invert__(self: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ile_s")) + def __le__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ishl", num_params=2)) # TODO: RHS is unsigned + def __lshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ilt_s")) + def __lt__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2)) + def __mod__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imul")) + def __mul__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ine")) + def __ne__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ineg")) + def __neg__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ior")) + def __or__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __pos__(self: int) -> int: + ... + + @guppy.custom(builtins, NotImplementedCompiler("ipow"), DefaultCallChecker()) # TODO + def __pow__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker(DefaultCallChecker())) + def __radd__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker(DefaultCallChecker())) + def __rand__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker(DefaultCallChecker())) + def __rdivmod__(self: int, other: int) -> tuple[int, int]: + ... + + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker(DefaultCallChecker())) + def __rfloordiv__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ishl", num_params=2), ReversingChecker(DefaultCallChecker())) # TODO: RHS is unsigned + def __rlshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker(DefaultCallChecker())) + def __rmod__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker(DefaultCallChecker())) + def __rmul__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ior"), ReversingChecker(DefaultCallChecker())) + def __ror__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __round__(self: int) -> int: + ... + + @guppy.custom(builtins, NotImplementedCompiler("ipow"), ReversingChecker(DefaultCallChecker())) # TODO + def __rpow__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ishr", num_params=2), ReversingChecker(DefaultCallChecker())) # TODO: RHS is unsigned + def __rrshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ishr", num_params=2)) # TODO: RHS is unsigned + def __rshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker(DefaultCallChecker())) + def __rsub__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker(DefaultCallChecker())) + def __rtruediv__(self: int, other: int) -> float: + ... + + @guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker(DefaultCallChecker())) + def __rxor__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("isub")) + def __sub__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, IntTruedivCompiler()) + def __truediv__(self: int, other: int) -> float: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __trunc__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ixor")) + def __xor__(self: int, other: int) -> int: + ... + + +@guppy.type(builtins, hugr_float_type, name="float") +class Float: + + @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) + def __abs__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fadd"), CoercingChecker()) + def __add__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatBoolCompiler(), CoercingChecker()) + def __bool__(self: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fceil"), CoercingChecker()) + def __ceil__(self: float) -> float: + ... + + @guppy.custom(builtins, FloatDivmodCompiler(), CoercingChecker()) + def __divmod__(self: float, other: float) -> tuple[float, float]: + ... + + @guppy.hugr_op(builtins, float_op("feq"), CoercingChecker()) + def __eq__(self: float, other: float) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) + def __float__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("ffloor"), CoercingChecker()) + def __floor__(self: float) -> float: + ... + + @guppy.custom(builtins, FloatFloordivCompiler(), CoercingChecker()) + def __floordiv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fge"), CoercingChecker()) + def __ge__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fgt"), CoercingChecker()) + def __gt__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker()) + def __int__(self: float) -> int: + ... + + @guppy.hugr_op(builtins, float_op("fle"), CoercingChecker()) + def __le__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("flt"), CoercingChecker()) + def __lt__(self: float, other: float) -> bool: + ... + + @guppy.custom(builtins, FloatModCompiler(), CoercingChecker()) + def __mod__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fmul"), CoercingChecker()) + def __mul__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fne"), CoercingChecker()) + def __ne__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fneg"), CoercingChecker()) + def __neg__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) + def __pos__(self: float) -> float: + ... + + @guppy.custom(builtins, NotImplementedCompiler("fpow"), CoercingChecker()) # TODO + def __pow__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fadd"), ReversingChecker(CoercingChecker())) + def __radd__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatDivmodCompiler(), ReversingChecker(CoercingChecker())) + def __rdivmod__(self: float, other: float) -> tuple[float, float]: + ... + + @guppy.custom(builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker())) + def __rfloordiv__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatModCompiler(), ReversingChecker(CoercingChecker())) + def __rmod__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fmul"), ReversingChecker(CoercingChecker())) + def __rmul__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, NotImplementedCompiler("fround"), ReversingChecker(CoercingChecker())) # TODO + def __round__(self: float) -> float: + ... + + @guppy.custom(builtins, NotImplementedCompiler("fpow"), ReversingChecker(CoercingChecker())) # TODO + def __rpow__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fsub"), ReversingChecker(CoercingChecker())) + def __rsub__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fdiv"), ReversingChecker(CoercingChecker())) + def __rtruediv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fsub"), CoercingChecker()) + def __sub__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fdiv"), CoercingChecker()) + def __truediv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker()) + def __trunc__(self: float) -> float: + ... + + +@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False) +def abs(x): + ... + + +@guppy.custom(builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False) +def _bool(x): + ... + + +@guppy.custom(builtins, checker=CallableChecker(), higher_order_value=False) +def callable(x): + ... + + +@guppy.custom(builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False) +def divmod(x, y): + ... + + +@guppy.custom(builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False) +def _float(x, y): + ... + + +@guppy.custom(builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False) +def _int(x): + ... + + +@guppy.custom(builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False) +def pow(x, y): + ... + + +@guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) +def round(x): + ... + diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index c0fcea8d..fd1b9bc3 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -1,64 +1,61 @@ -"""Guppy standard extension for quantum operations.""" +"""Guppy standard module for quantum operations.""" # mypy: disable-error-code=empty-body +from guppy.decorator import guppy +from guppy.hugr import tys, ops from guppy.hugr.tys import TypeBound -from guppy.prelude import builtin -from guppy.extension import GuppyExtension, OpCompiler -from guppy.hugr import ops, tys +from guppy.module import GuppyModule -class QuantumOpCompiler(OpCompiler): - def __init__(self, op_name: str, ext: str = "quantum.tket2"): - super().__init__(ops.CustomOp(extension=ext, op_name=op_name, args=[])) +quantum = GuppyModule("quantum") -_hugr_qubit = tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any) +def quantum_op(op_name: str) -> ops.OpType: + """Utility method to create Hugr quantum ops.""" + return ops.CustomOp(extension="quantum.tket2", op_name=op_name, args=[]) -extension = GuppyExtension("quantum.tket2", dependencies=[builtin]) - - -@extension.type(_hugr_qubit, linear=True) +@guppy.type(quantum, tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), linear=True) class Qubit: pass -@extension.func(QuantumOpCompiler("H")) +@guppy.hugr_op(quantum, quantum_op("H")) def h(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("CX")) +@guppy.hugr_op(quantum, quantum_op("CX")) def cx(control: Qubit, target: Qubit) -> tuple[Qubit, Qubit]: ... -@extension.func(QuantumOpCompiler("RzF64")) +@guppy.hugr_op(quantum, quantum_op("RzF64")) def rz(q: Qubit, angle: float) -> Qubit: ... -@extension.func(QuantumOpCompiler("Measure")) +@guppy.hugr_op(quantum, quantum_op("Measure")) def measure(q: Qubit) -> tuple[Qubit, bool]: ... -@extension.func(QuantumOpCompiler("T")) +@guppy.hugr_op(quantum, quantum_op("T")) def t(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("Tdg")) +@guppy.hugr_op(quantum, quantum_op("Tdg")) def tdg(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("Z")) +@guppy.hugr_op(quantum, quantum_op("Z")) def z(q: Qubit) -> Qubit: ... -@extension.func(QuantumOpCompiler("X")) +@guppy.hugr_op(quantum, quantum_op("X")) def x(q: Qubit) -> Qubit: ... From 5fb37d00b1733597c2701e8f80b2b9fd5de6ea66 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:01:01 +0000 Subject: [PATCH 10/77] Remove unused files --- guppy/compiler.py | 207 --------------------- guppy/compiler_base.py | 251 ------------------------- guppy/expression.py | 310 ------------------------------- guppy/extension.py | 386 --------------------------------------- guppy/function.py | 240 ------------------------ guppy/prelude/boolean.py | 47 ----- guppy/prelude/builtin.py | 230 ----------------------- guppy/prelude/float.py | 260 -------------------------- guppy/prelude/integer.py | 287 ----------------------------- guppy/statement.py | 177 ------------------ 10 files changed, 2395 deletions(-) delete mode 100644 guppy/compiler.py delete mode 100644 guppy/compiler_base.py delete mode 100644 guppy/expression.py delete mode 100644 guppy/extension.py delete mode 100644 guppy/function.py delete mode 100644 guppy/prelude/boolean.py delete mode 100644 guppy/prelude/builtin.py delete mode 100644 guppy/prelude/float.py delete mode 100644 guppy/prelude/integer.py delete mode 100644 guppy/statement.py diff --git a/guppy/compiler.py b/guppy/compiler.py deleted file mode 100644 index da4f3f20..00000000 --- a/guppy/compiler.py +++ /dev/null @@ -1,207 +0,0 @@ -import ast -import functools -import inspect -import sys -import textwrap -from dataclasses import dataclass -from types import ModuleType -from typing import Optional, Any, Callable, Union - -from guppy.ast_util import is_empty_body -from guppy.compiler_base import Globals -from guppy.extension import GuppyExtension -from guppy.function import FunctionDefCompiler, DefinedFunction -from guppy.hugr.hugr import Hugr -from guppy.error import GuppyError, SourceLoc - - -def format_source_location( - source_lines: list[str], - loc: Union[ast.AST, ast.operator, ast.expr, ast.arg, ast.Name], - line_offset: int, - num_lines: int = 3, - indent: int = 4, -) -> str: - """Creates a pretty banner to show source locations for errors.""" - assert loc.end_col_offset is not None # TODO - s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() - s += "\n" + loc.col_offset * " " + (loc.end_col_offset - loc.col_offset) * "^" - s = textwrap.dedent(s).splitlines() - # Add line numbers - line_numbers = [ - str(line_offset + loc.lineno - i) + ":" for i in range(num_lines, 0, -1) - ] - longest = max(len(ln) for ln in line_numbers) - prefixes = [ln + " " * (longest - len(ln) + indent) for ln in line_numbers] - res = "".join(prefix + line + "\n" for prefix, line in zip(prefixes, s[:-1])) - res += (longest + indent) * " " + s[-1] - return res - - -@dataclass -class RawFunction: - pyfun: Callable[..., Any] - ast: ast.FunctionDef - source_lines: list[str] - line_offset: int - - -class GuppyModule(object): - """A Guppy module backed by a Hugr graph. - - Instances of this class can be used as a decorator to add functions to the module. - After all functions are added, `compile()` must be called to obtain the Hugr. - """ - - name: str - globals: Globals - - _func_defs: dict[str, RawFunction] - _func_decls: dict[str, RawFunction] - - def __init__(self, name: str): - self.name = name - self.globals = Globals.default() - self._func_defs = {} - self._func_decls = {} - - # Load all prelude extensions - import guppy.prelude.builtin - import guppy.prelude.boolean - import guppy.prelude.float - import guppy.prelude.integer - - self.load(guppy.prelude.builtin) - self.load(guppy.prelude.boolean) - self.load(guppy.prelude.float) - self.load(guppy.prelude.integer) - - def register_func(self, f: Callable[..., Any]) -> None: - """Registers a Python function as belonging to this Guppy module. - - This can be used for both function definitions and declarations. To mark a - declaration, the body of the function may only contain an ellipsis expression. - """ - func = self._parse(f) - if is_empty_body(func.ast): - self._func_decls[func.ast.name] = func - else: - self._func_defs[func.ast.name] = func - - def load(self, m: Union[ModuleType, GuppyExtension]) -> None: - """Loads a Guppy extension from a python module. - - This function must be called for names from the extension to become available in - the Guppy. - """ - if isinstance(m, GuppyExtension): - self.globals |= m.globals - else: - for ext in m.__dict__.values(): - if isinstance(ext, GuppyExtension): - self.globals |= ext.globals - - def _parse(self, f: Callable[..., Any]) -> RawFunction: - source_lines, line_offset = inspect.getsourcelines(f) - line_offset -= 1 - source = "".join(source_lines) # Lines already have trailing \n's - source = textwrap.dedent(source) - func_ast = ast.parse(source).body[0] - if not isinstance(func_ast, ast.FunctionDef): - raise GuppyError("Only functions can be placed in modules", func_ast) - if func_ast.name in self._func_defs: - raise GuppyError( - f"Module `{self.name}` already contains a function named `{func_ast.name}` " - f"(declared at {SourceLoc.from_ast(self._func_defs[func_ast.name].ast, line_offset)})", - func_ast, - ) - return RawFunction(f, func_ast, source_lines, line_offset) - - def compile(self, exit_on_error: bool = False) -> Optional[Hugr]: - """Compiles the module and returns the final Hugr.""" - graph = Hugr(self.name) - module_node = graph.set_root_name(self.name) - try: - # Generate nodes for all function definition and declarations and add them - # to the globals - defs = {} - for name, f in self._func_defs.items(): - func_ty = FunctionDefCompiler.validate_signature(f.ast, self.globals) - def_node = graph.add_def(func_ty, module_node, f.ast.name) - defs[name] = def_node - self.globals.values[name] = DefinedFunction( - name, def_node.out_port(0), f.ast - ) - for name, f in self._func_decls.items(): - func_ty = FunctionDefCompiler.validate_signature(f.ast, self.globals) - if not is_empty_body(f.ast): - raise GuppyError( - "Function declarations may not have a body.", f.ast.body[0] - ) - decl_node = graph.add_declare(func_ty, module_node, f.ast.name) - self.globals.values[name] = DefinedFunction( - name, decl_node.out_port(0), f.ast - ) - - # Now compile functions definitions - for name, f in self._func_defs.items(): - FunctionDefCompiler(graph, self.globals).compile_global( - f.ast, defs[name] - ) - return graph - - except GuppyError as err: - if err.location: - loc = err.location - line = f.line_offset + loc.lineno - print( - "Guppy compilation failed. " - f"Error in file {inspect.getsourcefile(f.pyfun)}:{line}\n", - file=sys.stderr, - ) - print( - format_source_location(f.source_lines, loc, f.line_offset + 1), - file=sys.stderr, - ) - else: - print( - "Guppy compilation failed. " - f"Error in file {inspect.getsourcefile(f.pyfun)}\n", - file=sys.stderr, - ) - print( - f"{err.__class__.__name__}: {err.get_msg(f.line_offset)}", - file=sys.stderr, - ) - if exit_on_error: - sys.exit(1) - return None - - -def guppy( - arg: Union[Callable[..., Any], GuppyModule] -) -> Union[Optional[Hugr], Callable[[Callable[..., Any]], Callable[..., Any]]]: - """Decorator to annotate Python functions as Guppy code. - - Optionally, the `GuppyModule` in which the function should be placed can be passed - to the decorator. - """ - if isinstance(arg, GuppyModule): - - def dec(f: Callable[..., Any]) -> Callable[..., Any]: - assert isinstance(arg, GuppyModule) - arg.register_func(f) - - @functools.wraps(f) - def dummy(*args: Any, **kwargs: Any) -> Any: - raise GuppyError( - "Guppy functions can only be called in a Guppy context" - ) - - return dummy - - return dec - else: - module = GuppyModule("module") - module.register_func(arg) - return module.compile(exit_on_error=False) diff --git a/guppy/compiler_base.py b/guppy/compiler_base.py deleted file mode 100644 index 02fd4651..00000000 --- a/guppy/compiler_base.py +++ /dev/null @@ -1,251 +0,0 @@ -import ast -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Iterator, Optional, Any, NamedTuple - -from guppy.ast_util import AstNode -from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType -from guppy.hugr.hugr import OutPortV, Hugr, DFContainingNode - - -ValueName = str -TypeName = str - - -@dataclass -class RawVariable: - """Class holding data associated with a variable. - - Besides the name and type, we also store an AST node where the variable was defined. - """ - - name: ValueName - ty: GuppyType - defined_at: Optional[AstNode] - - def __lt__(self, other: Any) -> bool: - # We define an ordering on variables that is used to determine in which order - # they are outputted from basic blocks. We need to output linear variables at - # the end, so we do a lexicographic ordering of linearity and name, exploiting - # the fact that `False < True` in Python. The only exception are return vars - # which must be outputted in order. - if not isinstance(other, Variable): - return NotImplemented - if is_return_var(self.name) and is_return_var(other.name): - return self.name < other.name - return (self.ty.linear, self.name) < (other.ty.linear, other.name) - - -@dataclass -class Variable(RawVariable): - """Represents a concrete variable during compilation. - - Compared to a `RawVariable`, each variable corresponds to a Hugr port. - """ - - port: OutPortV - used: Optional[AstNode] = None - - def __init__(self, name: str, port: OutPortV, defined_at: Optional[AstNode]): - super().__init__(name, port.ty, defined_at) - object.__setattr__(self, "port", port) - - -# A dictionary mapping names to live variables -VarMap = dict[ValueName, Variable] - - -@dataclass -class DFContainer: - """A dataflow graph under construction. - - This class is passed through the entire compilation pipeline and stores the node - whose dataflow child-graph is currently being constructed as well as all live - variables. Note that the variable map is mutated in-place and always reflects the - current compilation state. - """ - - node: DFContainingNode - variables: VarMap - - def __getitem__(self, item: str) -> Variable: - return self.variables[item] - - def __setitem__(self, key: str, value: Variable) -> None: - self.variables[key] = value - - def __iter__(self) -> Iterator[Variable]: - return iter(self.variables.values()) - - def __contains__(self, item: str) -> bool: - return item in self.variables - - def __copy__(self) -> "DFContainer": - # Make a copy of the var map so that mutating the copy doesn't - # mutate our variable mapping - return DFContainer(self.node, self.variables.copy()) - - def get_var(self, name: str) -> Optional[Variable]: - return self.variables.get(name, None) - - -@dataclass -class GlobalVariable(ABC, RawVariable): - """Represents a global module-level variable.""" - - @abstractmethod - def load( - self, graph: Hugr, parent: DFContainingNode, globals: "Globals", node: AstNode - ) -> OutPortV: - """Loads the global variable as a value into a local dataflow graph.""" - - -@dataclass -class GlobalFunction(GlobalVariable, ABC): - """Represents a global module-level function.""" - - ty: FunctionType - call_compiler: "CallCompiler" - - def compile_call( - self, - args: list[OutPortV], - dfg: DFContainer, - graph: Hugr, - globals: "Globals", - node: AstNode, - ) -> list[OutPortV]: - """Utility method that invokes the local `CallCompiler` to compile a function - call""" - self.call_compiler.setup(dfg, graph, globals, self, node) - return self.call_compiler.compile(args) - - def compile_call_raw( - self, - args: list[ast.expr], - dfg: DFContainer, - graph: Hugr, - globals: "Globals", - node: AstNode, - ) -> list[OutPortV]: - """Utility method that invokes the local `CallCompiler` to compile a function - call with raw argument AST nodes""" - self.call_compiler.setup(dfg, graph, globals, self, node) - return self.call_compiler.compile_raw(args) - - -class Globals(NamedTuple): - """Collection of names that are available on module-level. - - Separately stores names that are bound to values (i.e. module-level functions or - constants), to types, or to instance functions belonging to types. - """ - - values: dict[ValueName, GlobalVariable] - types: dict[TypeName, type[GuppyType]] - instance_funcs: dict[tuple[TypeName, ValueName], GlobalFunction] - - @staticmethod - def default() -> "Globals": - """Generates a `Globals` instance that is populated with all core types""" - tys: dict[str, type[GuppyType]] = { - FunctionType.name: FunctionType, - TupleType.name: TupleType, - SumType.name: SumType, - } - return Globals({}, tys, {}) - - def get_instance_func(self, ty: GuppyType, name: str) -> Optional[GlobalFunction]: - """Looks up an instance function with a given name for a type""" - return self.instance_funcs.get((ty.name, name), None) - - def __or__(self, other: "Globals") -> "Globals": - return Globals( - self.values | other.values, - self.types | other.types, - self.instance_funcs | other.instance_funcs, - ) - - def __ior__(self, other: "Globals") -> "Globals": - self.values.update(other.values) - self.types.update(other.types) - self.instance_funcs.update(other.instance_funcs) - return self - - -class CompilerBase(ABC): - """Base class for the Guppy compiler.""" - - graph: Hugr - globals: Globals - - def __init__(self, graph: Hugr, globals: Globals) -> None: - self.graph = graph - self.globals = globals - - -class CallCompiler(ABC): - """Abstract base class for function call compilers.""" - - dfg: DFContainer - graph: Hugr - globals: Globals - func: GlobalFunction - node: AstNode - - @abstractmethod - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - """Compiles a function call with the given argument ports. - - Returns a row of output ports that are returned by the function. - """ - ... - - def compile_raw(self, args: list[ast.expr]) -> list[OutPortV]: - """Compiles a function call with raw argument AST nodes. - - The default implementation invokes the `ExpressionCompiler` to first compile the - arguments and then calls out to `compile`. - """ - from guppy.expression import ExpressionCompiler - - expr_compiler = ExpressionCompiler(self.graph, self.globals) - return self.compile([expr_compiler.compile(a, self.dfg) for a in args]) - - @property - def parent(self) -> DFContainingNode: - """The parent node of the current dataflow graph""" - return self.dfg.node - - def setup( - self, - dfg: DFContainer, - graph: Hugr, - globals: "Globals", - func: GlobalFunction, - node: AstNode, - ) -> None: - """Initialises the parameters of the call compiler. - - Must be called before trying to compile. - """ - self.dfg = dfg - self.graph = graph - self.globals = globals - self.func = func - self.node = node - - -def return_var(n: int) -> str: - """Name of the dummy variable for the n-th return value of a function. - - During compilation, we treat return statements like assignments of dummy variables. - For example, the statement `return e0, e1, e2` is treated like `%ret0 = e0 ; %ret1 = - e1 ; %ret2 = e2`. This way, we can reuse our existing mechanism for passing of live - variables between basic blocks.""" - return f"%ret{n}" - - -def is_return_var(x: str) -> bool: - """Checks whether the given name is a dummy return variable.""" - return x.startswith("%ret") diff --git a/guppy/expression.py b/guppy/expression.py deleted file mode 100644 index 44e3badd..00000000 --- a/guppy/expression.py +++ /dev/null @@ -1,310 +0,0 @@ -import ast -from typing import Any, Optional - -from guppy.ast_util import AstVisitor, AstNode -from guppy.compiler_base import ( - CompilerBase, - DFContainer, - GlobalFunction, - GlobalVariable, -) -from guppy.error import InternalGuppyError, GuppyTypeError, GuppyError -from guppy.guppy_types import FunctionType, GuppyType -from guppy.hugr import val, ops -from guppy.hugr.hugr import OutPortV - -# Mapping from unary AST op to dunder method and display name -unary_table: dict[type[AstNode], tuple[str, str]] = { - ast.UAdd: ("__pos__", "+"), - ast.USub: ("__neg__", "-"), - ast.Invert: ("__invert__", "~"), -} - -# Mapping from binary AST op to left dunder method, right dunder method and display name -binary_table: dict[type[AstNode], tuple[str, str, str]] = { - ast.Add: ("__add__", "__radd__", "+"), - ast.Sub: ("__sub__", "__rsub__", "-"), - ast.Mult: ("__mul__", "__rmul__", "*"), - ast.Div: ("__truediv__", "__rtruediv__", "/"), - ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), - ast.Mod: ("__mod__", "__rmod__", "%"), - ast.Pow: ("__pow__", "__rpow__", "**"), - ast.LShift: ("__lshift__", "__rlshift__", "<<"), - ast.RShift: ("__rshift__", "__rrshift__", ">>"), - ast.BitOr: ("__or__", "__ror__", "||"), - ast.BitXor: ("__xor__", "__rxor__", "^"), - ast.BitAnd: ("__and__", "__rand__", "&&"), - ast.MatMult: ("__matmul__", "__rmatmul__", "@"), - ast.Eq: ("__eq__", "__eq__", "=="), - ast.NotEq: ("__neq__", "__neq__", "!="), - ast.Lt: ("__lt__", "__gt__", "<"), - ast.LtE: ("__le__", "__ge__", "<="), - ast.Gt: ("__gt__", "__lt__", ">"), - ast.GtE: ("__ge__", "__le__", ">="), -} - - -def expr_to_row(expr: ast.expr) -> list[ast.expr]: - """Turns an expression into a row expressions by unpacking top-level tuples.""" - return expr.elts if isinstance(expr, ast.Tuple) else [expr] - - -class ExpressionCompiler(CompilerBase, AstVisitor[OutPortV]): - """A compiler from Python expressions to Hugr.""" - - dfg: DFContainer - - def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: - """Compiles an expression and returns a single port holding the output value.""" - self.dfg = dfg - with self.graph.parent(dfg.node): - res = self.visit(expr) - return res - - def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: - """Compiles a row expression and returns a list of ports, one for - each value in the row. - - On Python-level, we treat tuples like rows on top-level. However, - nested tuples are treated like regular Guppy tuples. - """ - return [self.compile(e, dfg) for e in expr_to_row(expr)] - - def _is_global_var(self, x: str) -> Optional[GlobalVariable]: - """Checks if the argument references a global variable. - - Returns the variable if it exists, otherwise `None`. - """ - if x in self.globals.values and x not in self.dfg: - return self.globals.values[x] - return None - - def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> Any: - raise GuppyError("Expression not supported", node) - - def visit_Constant(self, node: ast.Constant) -> OutPortV: - if val_ty := python_value_to_hugr(node.value): - const = self.graph.add_constant(*val_ty).out_port(0) - return self.graph.add_load_constant(const).out_port(0) - raise GuppyError("Unsupported constant expression", node) - - def visit_Name(self, node: ast.Name) -> OutPortV: - x = node.id - if x in self.dfg: - var = self.dfg[x] - if var.ty.linear and var.used is not None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - node, - [var.used], - ) - var.used = node - return self.dfg[x].port - elif x in self.globals.values: - return self.globals.values[x].load( - self.graph, self.dfg.node, self.globals, node - ) - raise InternalGuppyError( - f"Variable `{x}` is not defined in ExpressionCompiler. This should have " - f"been caught by program analysis!" - ) - - def visit_JoinedString(self, node: ast.JoinedStr) -> OutPortV: - raise GuppyError("Guppy does not support formatted strings", node) - - def visit_Tuple(self, node: ast.Tuple) -> OutPortV: - return self.graph.add_make_tuple( - inputs=[self.visit(e) for e in node.elts] - ).out_port(0) - - def visit_List(self, node: ast.List) -> OutPortV: - raise NotImplementedError() - - def visit_Set(self, node: ast.Set) -> OutPortV: - raise NotImplementedError() - - def visit_Dict(self, node: ast.Dict) -> OutPortV: - raise NotImplementedError() - - def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: - arg = self.visit(node.operand) - - # Special case for the `not` operation since it is not implemented via a dunder - # method or control-flow - if isinstance(node.op, ast.Not): - from guppy.prelude.builtin import BoolType - - func = self.globals.get_instance_func(arg.ty, "__bool__") - if func is None: - raise GuppyTypeError( - f"Expression of type `{arg.ty}` cannot be interpreted as a `bool`", - node.operand, - ) - [arg] = func.compile_call([arg], self.dfg, self.graph, self.globals, node) - return self.graph.add_node( - ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] - ).add_out_port(BoolType()) - - # Compile all other unary expressions by calling out to instance dunder methods - op, display_name = unary_table[node.op.__class__] - func = self.globals.get_instance_func(arg.ty, op) - if func is None: - raise GuppyTypeError( - f"Unary operator `{display_name}` not defined for argument of type " - f" `{arg.ty}`", - node.operand, - ) - [res] = func.compile_call([arg], self.dfg, self.graph, self.globals, node) - return res - - def _compile_binary( - self, left_expr: AstNode, right_expr: AstNode, op: AstNode, node: AstNode - ) -> OutPortV: - """Helper method to compile binary operators by calling out to dunder methods. - - For example, first try calling `__add__` on the left operand. If that fails, try - `__radd__` on the right operand. - """ - if op.__class__ not in binary_table: - raise GuppyError("This binary operation is not supported by Guppy.") - lop, rop, display_name = binary_table[op.__class__] - left, right = self.visit(left_expr), self.visit(right_expr) - - if func := self.globals.get_instance_func(left.ty, lop): - try: - [ret] = func.compile_call( - [left, right], self.dfg, self.graph, self.globals, node - ) - return ret - except GuppyError: - pass - - if func := self.globals.get_instance_func(right.ty, lop): - try: - [ret] = func.compile_call( - [left, right], self.dfg, self.graph, self.globals, node - ) - return ret - except GuppyError: - pass - - raise GuppyTypeError( - f"Binary operator `{display_name}` not defined for arguments of type " - f"`{left.ty}` and `{right.ty}`", - node, - ) - - def visit_BinOp(self, node: ast.BinOp) -> OutPortV: - return self._compile_binary(node.left, node.right, node.op, node) - - def visit_Compare(self, node: ast.Compare) -> OutPortV: - if len(node.comparators) != 1 or len(node.ops) != 1: - raise InternalGuppyError( - "BB contains chained comparison. Should have been removed during CFG " - "construction." - ) - left_expr, [op], [right_expr] = node.left, node.ops, node.comparators - return self._compile_binary(left_expr, right_expr, op, node) - - def visit_Call(self, node: ast.Call) -> OutPortV: - func = node.func - if len(node.keywords) > 0: - raise GuppyError( - "Argument passing by keyword is not supported", node.keywords[0] - ) - - # Special case for calls of global module-level functions. This also handles - # calls of extension functions. - if ( - isinstance(func, ast.Name) - and (f := self._is_global_var(func.id)) - and isinstance(f, GlobalFunction) - ): - returns = f.compile_call_raw( - node.args, self.dfg, self.graph, self.globals, node - ) - - # Otherwise, compile the function like any other expression - else: - port = self.visit(func) - args = [self.visit(arg) for arg in node.args] - if isinstance(port.ty, FunctionType): - type_check_call(port.ty, args, node) - call = self.graph.add_indirect_call(port, args) - returns = [call.out_port(i) for i in range(len(port.ty.returns))] - elif f := self.globals.get_instance_func(port.ty, "__call__"): - returns = f.compile_call(args, self.dfg, self.graph, self.globals, node) - else: - raise GuppyTypeError(f"Expected function type, got `{port.ty}`", func) - - # Group outputs into tuple - if len(returns) != 1: - return self.graph.add_make_tuple(inputs=returns).out_port(0) - return returns[0] - - def visit_NamedExpr(self, node: ast.NamedExpr) -> OutPortV: - raise InternalGuppyError( - "BB contains `NamedExpr`. Should have been removed during CFG" - f"construction: `{ast.unparse(node)}`" - ) - - def visit_BoolOp(self, node: ast.BoolOp) -> OutPortV: - raise InternalGuppyError( - "BB contains `BoolOp`. Should have been removed during CFG construction: " - f"`{ast.unparse(node)}`" - ) - - def visit_IfExp(self, node: ast.IfExp) -> OutPortV: - raise InternalGuppyError( - "BB contains `IfExp`. Should have been removed during CFG construction: " - f"`{ast.unparse(node)}`" - ) - - -def check_num_args(exp: int, act: int, node: AstNode) -> None: - """Checks that the correct number of arguments have been passed to a function.""" - if act < exp: - raise GuppyTypeError( - f"Not enough arguments passed (expected {exp}, got {act})", node - ) - if exp < act: - if isinstance(node, ast.Call): - raise GuppyTypeError("Unexpected argument", node.args[exp]) - raise GuppyTypeError( - f"Too many arguments passed (expected {exp}, got {act})", node - ) - - -def type_check_call(func_ty: FunctionType, args: list[OutPortV], node: AstNode) -> None: - """Type-checks the arguments for a function call.""" - check_num_args(len(func_ty.args), len(args), node) - for i, port in enumerate(args): - if port.ty != func_ty.args[i]: - raise GuppyTypeError( - f"Expected argument of type `{func_ty.args[i]}`, got `{port.ty}`", - node.args[i] if isinstance(node, ast.Call) else node, - ) - - -def python_value_to_hugr(v: Any) -> Optional[tuple[val.Value, GuppyType]]: - """Turns a Python value into a Hugr value together with its type. - - Returns None if the Python value cannot be represented in Guppy. - """ - from guppy.prelude.builtin import ( - IntType, - BoolType, - FloatType, - int_value, - bool_value, - float_value, - ) - - if isinstance(v, bool): - return bool_value(v), BoolType() - elif isinstance(v, int): - return int_value(v), IntType() - elif isinstance(v, float): - return float_value(v), FloatType() - return None diff --git a/guppy/extension.py b/guppy/extension.py deleted file mode 100644 index a19792a7..00000000 --- a/guppy/extension.py +++ /dev/null @@ -1,386 +0,0 @@ -import ast -import builtins -import inspect -import textwrap -from dataclasses import dataclass, field -from types import ModuleType -from typing import Optional, Callable, Any, Union, Sequence - -from guppy.ast_util import AstNode, is_empty_body -from guppy.compiler_base import ( - GlobalFunction, - Globals, - CallCompiler, - DFContainer, -) -from guppy.error import GuppyError, InternalGuppyError -from guppy.expression import type_check_call -from guppy.function import FunctionDefCompiler -from guppy.guppy_types import FunctionType, GuppyType -from guppy.hugr import ops, tys as tys -from guppy.hugr.hugr import Hugr, DFContainingNode, OutPortV, Node, DFContainingVNode - - -class ExtensionDefinitionError(Exception): - """Exception indicating a failure while defining an extension.""" - - def __init__(self, msg: str, extension: "GuppyExtension") -> None: - super().__init__( - f"Definition of extension `{extension.name}` is invalid: {msg}" - ) - - -@dataclass -class ExtensionFunction(GlobalFunction): - """A custom function to extend Guppy with functionality. - - Must be provided with a `CallCompiler` that handles compilation of calls to the - extension function. This allows for full flexibility in how extensions are compiled. - Additionally, it can be specified whether the function can be used as a value in a - higher-order context. - """ - - call_compiler: CallCompiler - higher_order_value: bool = True - - _defined: dict[Node, DFContainingVNode] = field(default_factory=dict, init=False) - - def load( - self, graph: Hugr, parent: DFContainingNode, globals: Globals, node: AstNode - ) -> OutPortV: - """Loads the extension function as a value into a local dataflow graph. - - This will place a `FunctionDef` node into the Hugr module if one for this - function doesn't already exist and loads it into the DFG. This operation will - fail if the extension function is marked as not supporting higher-order usage. - """ - if not self.higher_order_value: - raise GuppyError( - "This function does not support usage in a higher-order context", - node, - ) - - # Find the module node by walking up the hierarchy - module: Node = parent - while not isinstance(module.op, ops.Module): - if module.parent is None: - raise InternalGuppyError( - "Encountered node that is not contained in a module." - ) - module = module.parent - - # If the function has not yet been loaded in this module, we first have to - # define it. We create a `FunctionDef` that takes some inputs, compiles a call - # to the function, and returns the results. - if module not in self._defined: - def_node = graph.add_def(self.ty, module, self.name) - inp = graph.add_input(list(self.ty.args), parent=def_node) - returns = self.compile_call( - [inp.out_port(i) for i in range(len(self.ty.args))], - DFContainer(def_node, {}), - graph, - globals, - node, - ) - graph.add_output(returns, parent=def_node) - self._defined[module] = def_node - - # Finally, load the function into the local DFG - return graph.add_load_constant( - self._defined[module].out_port(0), parent - ).out_port(0) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - raise GuppyError("Tried to call Guppy function in a Python context") - - -class UntypedExtensionFunction(ExtensionFunction): - """An extension function that does not require a signature. - - As a result, functions like this cannot be used in a higher-order context. - """ - - def __init__( - self, name: str, defined_at: Optional[AstNode], call_compiler: CallCompiler - ) -> None: - self.name = name - self.defined_at = defined_at - self.higher_order = False - self.call_compiler = call_compiler - - @property # type: ignore - def ty(self) -> FunctionType: - raise InternalGuppyError( - "Tried to access signature from untyped extension function" - ) - - -class GuppyExtension: - """A Guppy extension. - - Consists of a collection of types, extension functions, and instance functions. - Note that extensions can also declare instance functions for types that are not - defined in this extension. - """ - - name: str - - # Globals for all new names defined by this extension - globals: Globals - - # Globals for this extension including core types and all dependencies - _all_globals: Globals - - def __init__(self, name: str, dependencies: Sequence[ModuleType]) -> None: - """Creates a new empty Guppy extension. - - If the extension uses types from other extensions (for example the `builtin.py` - extensions from the prelude), they have to be passed as dependencies. - """ - - self.name = name - self.globals = Globals({}, {}, {}) - self._all_globals = Globals.default() - - for module in dependencies: - exts = [ - obj - for obj in module.__dict__.values() - if isinstance(obj, GuppyExtension) - ] - if len(exts) == 0: - raise ExtensionDefinitionError( - f"Dependency module `{module.__name__}` does not contain a Guppy extension", - self, - ) - for ext in exts: - self._all_globals |= ext.globals - - def register_type(self, name: str, ty: type[GuppyType]) -> None: - """Registers an existing `GuppyType` subclass with this extension.""" - self.globals.types[name] = ty - self._all_globals.types[name] = ty - - def register_func(self, name: str, func: "ExtensionFunction") -> None: - """Registers an existing `ExtensionFunction` with this extension.""" - self.globals.values[name] = func - - def register_instance_func( - self, ty: type[GuppyType], name: str, func: "ExtensionFunction" - ) -> None: - """Registers an existing function as an instance function for the type with the - given name.""" - self.globals.instance_funcs[ty.name, name] = func - - def new_type( - self, name: str, hugr_repr: tys.SimpleType, linear: bool = False - ) -> type[GuppyType]: - """Creates a new type. - - Requires the static Hugr translation of the type. Additionally, the type can be - marked as linear. - """ - _name = name - - class NewType(GuppyType): - name = _name - - @staticmethod - def build( - *args: GuppyType, node: Union[ast.Name, ast.Subscript] - ) -> "GuppyType": - # At the moment, extension types don't support type arguments. - if len(args) > 0: - raise GuppyError( - f"Type `{name}` does not accept type parameters.", node - ) - return NewType() - - @property - def linear(self) -> bool: - return linear - - def to_hugr(self) -> tys.SimpleType: - return hugr_repr - - def __eq__(self, other: Any) -> bool: - return isinstance(other, NewType) - - def __str__(self) -> str: - return name - - NewType.__name__ = NewType.__qualname__ = name - self.register_type(name, NewType) - return NewType - - def type( - self, - hugr_repr: tys.SimpleType, - alias: Optional[str] = None, - linear: bool = False, - ) -> Callable[[type], type]: - """Class decorator to annotate a new Guppy type. - - Requires the static Hugr translation of the type. Additionally, the type can be - marked as linear and an alias can be provided to be used in place of the class - name. - """ - - def decorator(cls: type) -> type: - return self.new_type(alias or cls.__name__, hugr_repr, linear) - - return decorator - - def new_func( - self, - name: str, - call_compiler: CallCompiler, - signature: Optional[FunctionType] = None, - higher_order_value: bool = True, - instance: Optional[builtins.type[GuppyType]] = None, - ) -> ExtensionFunction: - """Creates a new extension function. - - Passing a `GuppyType` with `instance=...` marks the function as an instance - function for the given type. A type signature may be omitted if higher-order - usage of the function is disabled. - """ - func: ExtensionFunction - if signature is None: - if higher_order_value: - raise ExtensionDefinitionError( - "Signature may only be omitted if `higher_order=False` is set", - self, - ) - func = UntypedExtensionFunction(name, None, call_compiler) # TODO: Location - else: - func = ExtensionFunction( - name, - signature, - None, - call_compiler, - higher_order_value, # TODO: Location - ) - if instance is not None: - self.register_instance_func(instance, name, func) - else: - self.register_func(name, func) - - return func - - def func( - self, - call_compiler: CallCompiler, - alias: Optional[str] = None, - higher_order_value: bool = True, - instance: Optional[builtins.type[GuppyType]] = None, - ) -> Callable[[Callable[..., Any]], ExtensionFunction]: - """Decorator to annotate a new extension function. - - Passing a `GuppyType` with `instance=...` marks the function as an instance - function for the given type. The type signature is extracted from the Python - type annotations on the function. They may only be omitted if higher-order - usage of the function is disabled. - """ - - def decorator(f: Callable[..., Any]) -> ExtensionFunction: - func_ast, ty = self._parse_decl(f) - name = alias or func_ast.name - return self.new_func(name, call_compiler, ty, higher_order_value, instance) - - return decorator - - def _parse_decl( - self, f: Callable[..., Any] - ) -> tuple[ast.FunctionDef, Optional[FunctionType]]: - """Helper method to parse a function into an AST. - - Also returns the function type extracted from the type annotations if they are - provided. - """ - source = textwrap.dedent(inspect.getsource(f)) - func_ast = ast.parse(source).body[0] - if not isinstance(func_ast, ast.FunctionDef): - raise ExtensionDefinitionError( - "Only functions may be annotated using `@extension`", self - ) - if not is_empty_body(func_ast): - raise ExtensionDefinitionError( - "Body of declared extension functions must be empty", self - ) - # Return None if annotations are missing - if not func_ast.returns or not all( - arg.annotation for arg in func_ast.args.args - ): - return func_ast, None - - return func_ast, FunctionDefCompiler.validate_signature( - func_ast, self._all_globals - ) - - -class OpCompiler(CallCompiler): - """Compiler for calls that can be implemented via a single Hugr op. - - Performs type checking against the signature of the function and inserts the - specified op into the graph. - """ - - op: ops.OpType - signature: Optional[FunctionType] = None - - def __init__(self, op: ops.OpType, signature: Optional[FunctionType] = None): - self.op = op - self.signature = signature - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - func_ty = self.signature or self.func.ty - type_check_call(func_ty, args, self.node) - leaf = self.graph.add_node(self.op.copy(), inputs=args, parent=self.parent) - return [leaf.add_out_port(ty) for ty in func_ty.returns] - - -class IdOpCompiler(CallCompiler): - """Compiler for calls that are no-ops. - - Compiles a call by directly returning the arguments. Type checking can be disabled - by passing `type_check=False`. - """ - - type_check: bool - - def __init__(self, type_check: bool = True): - self.type_check = type_check - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - if self.type_check: - func_ty = self.func.ty - type_check_call(func_ty, args, self.node) - return args - - -class Reversed(CallCompiler): - """Call compiler that reverses the arguments and calls out to a different compiler. - - Useful to implement the right-hand version of arithmetic functions, e.g. `__radd__`. - """ - - cc: CallCompiler - - def __init__(self, cc: CallCompiler): - self.cc = cc - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - self.cc.setup(self.dfg, self.graph, self.globals, self.func, self.node) - return self.cc.compile(list(reversed(args))) - - -class NotImplementedCompiler(CallCompiler): - """Call compiler that raises an error when the function is called. - - Should be used to inform users that a function that would normally be available in - Python is not yet implemented in Guppy. - """ - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - raise GuppyError("Operation is not yet implemented", self.node) diff --git a/guppy/function.py b/guppy/function.py deleted file mode 100644 index 33fae04c..00000000 --- a/guppy/function.py +++ /dev/null @@ -1,240 +0,0 @@ -import ast - -from guppy.ast_util import return_nodes_in_ast, AstNode -from guppy.cfg.bb import BB, NestedFunctionDef -from guppy.cfg.builder import CFGBuilder -from guppy.compiler_base import ( - CompilerBase, - RawVariable, - DFContainer, - Globals, - GlobalFunction, - CallCompiler, -) -from guppy.error import GuppyError -from guppy.expression import type_check_call -from guppy.guppy_types import ( - FunctionType, - type_row_from_ast, - type_from_ast, -) -from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode, DFContainingNode - - -class DefinedFunction(GlobalFunction): - ty: FunctionType - defined_at: ast.FunctionDef - port: OutPortV - - def __init__(self, name: str, port: OutPortV, defined_at: ast.FunctionDef): - assert isinstance(port.ty, FunctionType) - super().__init__(name, port.ty, defined_at, self.DefCallCompiler()) - self.port = port - - def load( - self, graph: Hugr, parent: DFContainingNode, globals: Globals, node: AstNode - ) -> OutPortV: - """Loads the function as a value into a local dataflow graph.""" - return graph.add_load_constant(self.port, parent).out_port(0) - - class DefCallCompiler(CallCompiler): - """Compiler for calls to defined functions.""" - - func: "DefinedFunction" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # Defined functions can be called using a regular direct call op - type_check_call(self.func.ty, args, self.node) - call = self.graph.add_call(self.func.port, args, self.parent) - return [call.out_port(i) for i in range(len(self.func.ty.returns))] - - -class FunctionDefCompiler(CompilerBase): - cfg_builder: CFGBuilder - - def __init__(self, graph: Hugr, globals: Globals): - super().__init__(graph, globals) - self.cfg_builder = CFGBuilder() - - @staticmethod - def validate_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: - """Checks the signature of a function definition and returns the corresponding - Guppy type.""" - if len(func_def.args.posonlyargs) != 0: - raise GuppyError( - "Positional-only parameters not supported", func_def.args.posonlyargs[0] - ) - if len(func_def.args.kwonlyargs) != 0: - raise GuppyError( - "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] - ) - if func_def.args.vararg is not None: - raise GuppyError("*args not supported", func_def.args.vararg) - if func_def.args.kwarg is not None: - raise GuppyError("**kwargs not supported", func_def.args.kwarg) - if func_def.returns is None: - # TODO: Error location is incorrect - if all(r.value is None for r in return_nodes_in_ast(func_def)): - raise GuppyError( - "Return type must be annotated. Try adding a `-> None` annotation.", - func_def, - ) - raise GuppyError("Return type must be annotated", func_def) - - arg_tys = [] - arg_names = [] - for i, arg in enumerate(func_def.args.args): - if arg.annotation is None: - raise GuppyError("Argument type must be annotated", arg) - ty = type_from_ast(arg.annotation, globals) - arg_tys.append(ty) - arg_names.append(arg.arg) - - ret_type_row = type_row_from_ast(func_def.returns, globals) - return FunctionType(arg_tys, ret_type_row.tys, arg_names) - - def compile_global( - self, - func_def: ast.FunctionDef, - def_node: DFContainingVNode, - ) -> DefinedFunction: - """Compiles a top-level function definition.""" - func_ty = self.validate_signature(func_def, self.globals) - args = func_def.args.args - - cfg = self.cfg_builder.build(func_def.body, len(func_ty.returns), self.globals) - - def_input = self.graph.add_input(parent=def_node) - cfg_node = self.graph.add_cfg( - def_node, inputs=[def_input.add_out_port(ty) for ty in func_ty.args] - ) - assert func_ty.arg_names is not None - input_sig = [ - RawVariable(x, ty, loc) - for x, ty, loc in zip(func_ty.arg_names, func_ty.args, args) - ] - cfg.compile( - self.graph, - input_sig, - list(func_ty.returns), - cfg_node, - self.globals, - ) - - # Add final output node for the def block - self.graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in func_ty.returns], - parent=def_node, - ) - - return DefinedFunction(func_def.name, def_node.out_port(0), func_def) - - def compile_local( - self, - func_def: NestedFunctionDef, - dfg: DFContainer, - bb: BB, - ) -> DefinedFunction: - """Compiles a local (nested) function definition.""" - func_ty = self.validate_signature(func_def, self.globals) - args = func_def.args.args - assert func_ty.arg_names is not None - - # We've already computed the CFG for this function while computing the CFG of - # the enclosing function - cfg = func_def.cfg - - # Find captured variables - parent_cfg = bb.cfg - def_ass_before = set(func_ty.arg_names) | dfg.variables.keys() - maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] - cfg.analyze(len(func_ty.returns), def_ass_before, maybe_ass_before) - captured = [ - dfg[x] - for x in cfg.live_before[cfg.entry_bb] - if x not in func_ty.arg_names and x in dfg - ] - - # Captured variables may not be linear - for v in captured: - if v.ty.linear: - x = v.name - using_bb = cfg.live_before[cfg.entry_bb][x] - raise GuppyError( - f"Variable `{x}` with linear type `{v.ty}` may not be used here " - f"because it was defined in an outer scope (at {{0}})", - using_bb.vars.used[x], - [v.defined_at], - ) - - # Captured variables may never be assigned to - for bb in cfg.bbs: - for v in captured: - x = v.name - if x in bb.vars.assigned: - raise GuppyError( - f"Variable `{x}` defined in an outer scope (at {{0}}) may not " - f"be assigned to", - bb.vars.assigned[x], - [v.defined_at], - ) - - # Prepend captured variables to the function arguments - closure_ty = FunctionType( - [v.ty for v in captured] + list(func_ty.args), - func_ty.returns, - [v.name for v in captured] + list(func_ty.arg_names), - ) - - def_node = self.graph.add_def(closure_ty, dfg.node, func_def.name) - def_input = self.graph.add_input(parent=def_node) - input_ports = [def_input.add_out_port(ty) for ty in closure_ty.args] - input_row = captured + [ - RawVariable(x, ty, loc) - for x, ty, loc in zip(func_ty.arg_names, func_ty.args, args) - ] - - # If we have captured variables and the body contains a recursive occurrence of - # the function itself, then we pass a version of the function with applied - # captured arguments as an extra argument. - if len(captured) > 0 and func_def.name in cfg.live_before[cfg.entry_bb]: - loaded = self.graph.add_load_constant(def_node.out_port(0), parent=def_node) - partial = self.graph.add_partial( - loaded.out_port(0), args=input_ports[: len(captured)], parent=def_node - ) - input_ports += [partial.out_port(0)] - input_row += [RawVariable(func_def.name, func_ty, func_def)] - global_values = self.globals.values - # Otherwise, we can treat the function like a normal global variable - else: - global_values = self.globals.values | { - func_def.name: DefinedFunction( - func_def.name, def_node.out_port(0), func_def - ) - } - globals = Globals( - global_values, self.globals.types, self.globals.instance_funcs - ) - - cfg_node = self.graph.add_cfg(def_node, inputs=input_ports) - cfg.compile(self.graph, input_row, list(func_ty.returns), cfg_node, globals) - - # Add final output node for the def block - self.graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in func_ty.returns], - parent=def_node, - ) - - # Finally, add partial application node to supply the captured arguments - loaded = self.graph.add_load_constant(def_node.out_port(0), parent=dfg.node) - if len(captured) > 0: - # TODO: We can probably get rid of the load here once we have a resource - # that supports partial application, instead of using a dummy Op here. - partial = self.graph.add_partial( - loaded.out_port(0), args=[v.port for v in captured], parent=dfg.node - ) - port = partial.out_port(0) - else: - port = loaded.out_port(0) - - return DefinedFunction(func_def.name, port, func_def) diff --git a/guppy/prelude/boolean.py b/guppy/prelude/boolean.py deleted file mode 100644 index f124d503..00000000 --- a/guppy/prelude/boolean.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Guppy standard extension for bool operations.""" - -# mypy: disable-error-code=empty-body - -from guppy.prelude import builtin -from guppy.prelude.builtin import BoolType -from guppy.extension import ( - GuppyExtension, - OpCompiler, - IdOpCompiler, - NotImplementedCompiler, -) -from guppy.hugr import ops -from guppy.prelude.integer import IntOpCompiler - - -class BoolOpCompiler(OpCompiler): - def __init__(self, op_name: str): - super().__init__(ops.CustomOp(extension="logic", op_name=op_name, args=[])) - - -ext = GuppyExtension("boolean", [builtin]) - - -@ext.func(BoolOpCompiler("And"), instance=BoolType) -def __and__(self: bool, other: bool) -> bool: - ... - - -@ext.func(IdOpCompiler(), instance=BoolType) -def __bool__(self: bool) -> bool: - ... - - -@ext.func(IntOpCompiler("ifrombool"), instance=BoolType) -def __int__(self: bool) -> int: - ... - - -@ext.func(BoolOpCompiler("Or"), instance=BoolType) -def __or__(self: bool, other: bool) -> bool: - ... - - -@ext.func(NotImplementedCompiler(), instance=BoolType) # TODO -def __str__(self: int) -> str: - ... diff --git a/guppy/prelude/builtin.py b/guppy/prelude/builtin.py deleted file mode 100644 index 5134ba78..00000000 --- a/guppy/prelude/builtin.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Guppy standard extension for builtin types and methods. - -Instance methods for builtin types are defined in their own files -""" - -import ast -from typing import Union, Literal, Any - -from pydantic import BaseModel - -from guppy.compiler_base import CallCompiler -from guppy.error import GuppyError, GuppyTypeError -from guppy.expression import check_num_args -from guppy.extension import GuppyExtension, NotImplementedCompiler -from guppy.guppy_types import SumType, TupleType, GuppyType, FunctionType -from guppy.hugr import tys, val -from guppy.hugr.hugr import OutPortV -from guppy.hugr.tys import TypeBound - - -extension = GuppyExtension("builtin", []) - - -# We have to define and register the bool type by hand since we want it to be a -# subclass of `SumType` - - -class BoolType(SumType): - """The type of booleans.""" - - linear = False - name = "bool" - - def __init__(self) -> None: - # Hugr bools are encoded as Sum((), ()) - super().__init__([TupleType([]), TupleType([])]) - - @staticmethod - def build(*args: GuppyType, node: Union[ast.Name, ast.Subscript]) -> GuppyType: - if len(args) > 0: - raise GuppyError("Type `bool` is not parametric", node) - return BoolType() - - def __str__(self) -> str: - return "bool" - - def __eq__(self, other: Any) -> bool: - return isinstance(other, BoolType) - - -extension.register_type("bool", BoolType) - - -INT_WIDTH = 6 # 2^6 = 64 bit - -IntType: type[GuppyType] = extension.new_type( - name="int", - hugr_repr=tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=INT_WIDTH)], - bound=TypeBound.Eq, - ), -) - -FloatType: type[GuppyType] = extension.new_type( - name="float", - hugr_repr=tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=TypeBound.Copyable, - ), -) - -StringType: type[GuppyType] = extension.new_type( - name="str", - hugr_repr=tys.Opaque( - extension="TODO", # String hugr extension doesn't exist yet - id="string", - args=[], - bound=TypeBound.Eq, - ), -) - - -class ConstIntS(BaseModel): - """Hugr representation of signed integers in the arithmetic extension.""" - - c: Literal["ConstIntS"] = "ConstIntS" - log_width: int - value: int - - -class ConstF64(BaseModel): - """Hugr representation of floats in the arithmetic extension.""" - - c: Literal["ConstF64"] = "ConstF64" - value: float - - -def bool_value(b: bool) -> val.Value: - """Returns the Hugr representation of a boolean value.""" - return val.Sum(tag=int(b), value=val.Tuple(vs=[])) - - -def int_value(i: int) -> val.Value: - """Returns the Hugr representation of an integer value.""" - return val.Prim(val=val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),))) - - -def float_value(f: float) -> val.Value: - """Returns the Hugr representation of a float value.""" - return val.Prim(val=val.ExtensionVal(c=(ConstF64(value=f),))) - - -class CallableCompiler(CallCompiler): - """Call compiler for the builtin `callable` function""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - check_num_args(1, len(args), self.node) - [arg] = args - is_callable = ( - isinstance(arg.ty, FunctionType) - or self.globals.get_instance_func(arg.ty, "__call__") is not None - ) - const = self.graph.add_constant(bool_value(is_callable), BoolType()).out_port(0) - return [self.graph.add_load_constant(const).out_port(0)] - - -class BuiltinCompiler(CallCompiler): - """Call compiler for builtin functions that call out to dunder instance methods""" - - dunder_name: str - num_args: int - - def __init__(self, dunder_name: str, num_args: int = 1): - self.dunder_name = dunder_name - self.num_args = num_args - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - check_num_args(self.num_args, len(args), self.node) - [arg] = args - func = self.globals.get_instance_func(arg.ty, self.dunder_name) - if func is None: - raise GuppyTypeError( - f"Builtin function `{self.func.name}` is not defined for argument of " - "type `{arg.ty}`", - self.node.args[0] if isinstance(self.node, ast.Call) else self.node, - ) - return func.compile_call(args, self.dfg, self.graph, self.globals, self.node) - - -extension.new_func("abs", BuiltinCompiler("__abs__"), higher_order_value=False) -extension.new_func("bool", BuiltinCompiler("__bool__"), higher_order_value=False) -extension.new_func("callable", CallableCompiler(), higher_order_value=False) -extension.new_func("divmod", BuiltinCompiler("__divmod__"), higher_order_value=False) -extension.new_func("float", BuiltinCompiler("__float__"), higher_order_value=False) -extension.new_func("int", BuiltinCompiler("__int__"), higher_order_value=False) -extension.new_func("len", BuiltinCompiler("__len__"), higher_order_value=False) -extension.new_func( - "pow", BuiltinCompiler("__pow__", num_args=2), higher_order_value=False -) -extension.new_func("repr", BuiltinCompiler("__repr__"), higher_order_value=False) -extension.new_func("round", BuiltinCompiler("__round__"), higher_order_value=False) -extension.new_func("str", BuiltinCompiler("__str__"), higher_order_value=False) - - -# Python builtins that are not supported yet -extension.new_func("aiter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("all", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("anext", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("any", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("ascii", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("aiter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("bin", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("breakpoint", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("bytearray", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("bytes", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("chr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("classmethod", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("compile", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("complex", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("delattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("dict", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("dir", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("enumerate", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("eval", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("exec", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("filter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("format", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("frozenset", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("getattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("globals", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("hasattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("hash", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("help", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("hex", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("id", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("input", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("isinstance", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("issubclass", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("iter", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("list", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("locals", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("map", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("max", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("memoryview", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("min", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("next", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("object", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("oct", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("open", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("ord", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("print", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("property", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("range", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("reversed", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("set", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("setattr", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("slice", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("sorted", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("staticmethod", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("sum", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("super", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("tuple", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("type", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("vars", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("zip", NotImplementedCompiler(), higher_order_value=False) -extension.new_func("__import__", NotImplementedCompiler(), higher_order_value=False) diff --git a/guppy/prelude/float.py b/guppy/prelude/float.py deleted file mode 100644 index e59369eb..00000000 --- a/guppy/prelude/float.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Guppy standard extension for float operations.""" - -# mypy: disable-error-code=empty-body - -from guppy.prelude import builtin -from guppy.prelude.builtin import IntType, FloatType, float_value -from guppy.compiler_base import CallCompiler -from guppy.extension import ( - GuppyExtension, - OpCompiler, - Reversed, - NotImplementedCompiler, - IdOpCompiler, -) -from guppy.hugr import ops -from guppy.hugr.hugr import OutPortV - - -class FloatOpCompiler(OpCompiler): - """Compiler for calls that can be implemented via a single Hugr float op""" - - def __init__(self, op_name: str, extension: str = "arithmetic.float"): - super().__init__(ops.CustomOp(extension=extension, op_name=op_name, args=[])) - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - args = [ - self.graph.add_node( - ops.CustomOp(extension="arithmetic.conversions", op_name="convert_s"), - inputs=[arg], - parent=self.parent, - ).add_out_port(FloatType()) - if isinstance(arg.ty, IntType) - else arg - for arg in args - ] - return super().compile(args) - - -class BoolCompiler(CallCompiler): - """Compiler for the `__bool__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: bool(x) = (x != 0.0) - zero_const = self.graph.add_constant(float_value(0.0), FloatType(), self.parent) - zero = self.graph.add_load_constant(zero_const.out_port(0), self.parent) - return __ne__.compile_call( - [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node - ) - - -class FloordivCompiler(CallCompiler): - """Compiler for the `__floordiv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: floordiv(x, y) = floor(truediv(x, y)) - [div] = __truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [floor] = __floor__.compile_call( - [div], self.dfg, self.graph, self.globals, self.node - ) - return [floor] - - -class ModCompiler(CallCompiler): - """Compiler for the `__mod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: mod(x, y) = x - (x // y) * y - [div] = __floordiv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [mul] = __mul__.compile_call( - [div, args[1]], self.dfg, self.graph, self.globals, self.node - ) - [sub] = __sub__.compile_call( - [args[0], mul], self.dfg, self.graph, self.globals, self.node - ) - return [sub] - - -class DivmodCompiler(CallCompiler): - """Compiler for the `__divmod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # We have: divmod(x, y) = (div(x, y), mod(x, y)) - [div] = __truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [mod] = __mod__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - return [self.graph.add_make_tuple([div, mod], self.parent).out_port(0)] - - -extension = GuppyExtension("float", [builtin]) - - -@extension.func(FloatOpCompiler("fabs"), instance=FloatType) -def __abs__(self: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fadd"), instance=FloatType) -def __add__(self: float, other: float) -> float: - ... - - -@extension.func(BoolCompiler(), instance=FloatType) -def __bool__(self: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("fceil"), instance=FloatType) -def __ceil__(self: float) -> float: - ... - - -@extension.func(DivmodCompiler(), instance=FloatType) -def __divmod__(self: float, other: float) -> tuple[float, float]: - ... - - -@extension.func(FloatOpCompiler("feq"), instance=FloatType) -def __eq__(self: float, other: float) -> bool: - ... - - -@extension.func(IdOpCompiler(), instance=FloatType) -def __float__(self: float) -> float: - ... - - -@extension.func(FloatOpCompiler("ffloor"), instance=FloatType) -def __floor__(self: float) -> float: - ... - - -@extension.func(FloordivCompiler(), instance=FloatType) -def __floordiv__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fge"), instance=FloatType) -def __ge__(self: float, other: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("fgt"), instance=FloatType) -def __gt__(self: float, other: float) -> bool: - ... - - -@extension.func( - FloatOpCompiler("trunc_s", "arithmetic.conversions"), instance=FloatType -) -def __int__(self: float, other: float) -> int: - ... - - -@extension.func(FloatOpCompiler("fle"), instance=FloatType) -def __le__(self: float, other: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("flt"), instance=FloatType) -def __lt__(self: float, other: float) -> bool: - ... - - -@extension.func(ModCompiler(), instance=FloatType) -def __mod__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fmul"), instance=FloatType) -def __mul__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fne"), instance=FloatType) -def __ne__(self: float, other: float) -> bool: - ... - - -@extension.func(FloatOpCompiler("fneg"), instance=FloatType) -def __neg__(self: float, other: float) -> float: - ... - - -@extension.func(IdOpCompiler(), instance=FloatType) -def __pos__(self: float) -> float: - ... - - -@extension.func(NotImplementedCompiler(), instance=FloatType) # TODO -def __pow__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(FloatOpCompiler("fadd")), instance=FloatType) -def __radd__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(DivmodCompiler()), instance=FloatType) -def __rdivmod__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(ModCompiler()), instance=FloatType) -def __rmod__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(FloatOpCompiler("fmul")), instance=FloatType) -def __rmul__(self: float, other: float) -> float: - ... - - -@extension.func(NotImplementedCompiler(), instance=FloatType) # TODO -def __round__(self: float) -> float: - ... - - -@extension.func(Reversed(NotImplementedCompiler()), instance=FloatType) # TODO -def __rpow__(self: float, other: float) -> float: - ... - - -@extension.func(Reversed(FloatOpCompiler("fsub")), instance=FloatType) -def __rsub__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fdiv"), instance=FloatType) -def __rtruediv__(self: float, other: float) -> float: - ... - - -@extension.func(NotImplementedCompiler(), instance=FloatType) # TODO -def __str__(self: float) -> str: - ... - - -@extension.func(FloatOpCompiler("fsub"), instance=FloatType) -def __sub__(self: float, other: float) -> float: - ... - - -@extension.func(FloatOpCompiler("fdiv"), instance=FloatType) -def __truediv__(self: float, other: float) -> float: - ... - - -@extension.func( - FloatOpCompiler("trunc_s", "arithmetic.conversions"), instance=FloatType -) -def __trunc__(self: float, other: float) -> int: - ... diff --git a/guppy/prelude/integer.py b/guppy/prelude/integer.py deleted file mode 100644 index 182884aa..00000000 --- a/guppy/prelude/integer.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Guppy standard extension for int operations.""" - -# mypy: disable-error-code=empty-body - -from guppy.compiler_base import CallCompiler -from guppy.expression import type_check_call -from guppy.guppy_types import FunctionType -from guppy.hugr.hugr import OutPortV -from guppy.prelude import builtin -from guppy.prelude.builtin import IntType, INT_WIDTH, FloatType -from guppy.extension import ( - GuppyExtension, - OpCompiler, - Reversed, - NotImplementedCompiler, - IdOpCompiler, -) -from guppy.hugr import ops, tys - - -class IntOpCompiler(OpCompiler): - """Compiler for calls that can be implemented via a single Hugr integer op""" - - def __init__(self, op_name: str, ext: str = "arithmetic.int", num_params: int = 1): - super().__init__( - ops.CustomOp( - extension=ext, - op_name=op_name, - args=num_params * [tys.BoundedNatArg(n=INT_WIDTH)], - ) - ) - - -class TruedivCompiler(CallCompiler): - """Compiler for the `__truediv__` function""" - - signature: FunctionType = FunctionType([IntType(), IntType()], [FloatType()]) - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - # Compile `truediv` using float arithmetic - import guppy.prelude.float - - type_check_call(self.signature, args, self.node) - [left, right] = args - [left] = __float__.compile_call( - [left], self.dfg, self.graph, self.globals, self.node - ) - [right] = __float__.compile_call( - [right], self.dfg, self.graph, self.globals, self.node - ) - return guppy.prelude.float.__truediv__.compile_call( - [left, right], self.dfg, self.graph, self.globals, self.node - ) - - -extension = GuppyExtension("integer", dependencies=[builtin]) - - -# TODO: Maybe wrong?? (signed vs unsigned!) -@extension.func(IntOpCompiler("iabs"), instance=IntType) -def __abs__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("iadd"), instance=IntType) -def __add__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("iand"), instance=IntType) -def __and__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("itobool"), instance=IntType) -def __bool__(self: int) -> bool: - ... - - -@extension.func(OpCompiler(ops.Noop(ty=IntType().to_hugr())), instance=IntType) -def __ceil__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("idivmod_s", num_params=2), instance=IntType) -def __divmod__(self: int, other: int) -> tuple[int, int]: - ... - - -@extension.func(IntOpCompiler("ieq"), instance=IntType) -def __eq__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("convert_s", "arithmetic.conversions"), instance=IntType) -def __float__(self: int) -> float: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __floor__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("idiv_s", num_params=2), instance=IntType) -def __floordiv__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ige_s"), instance=IntType) -def __ge__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("igt_s"), instance=IntType) -def __gt__(self: int, other: int) -> bool: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __int__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("inot"), instance=IntType) -def __invert__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("ile_s"), instance=IntType) -def __le__(self: int, other: int) -> bool: - ... - - -@extension.func( - IntOpCompiler("ishl", num_params=2), instance=IntType -) # TODO: broken (RHS is unsigned) -def __lshift__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ilt_s"), instance=IntType) -def __lt__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("imod_s", num_params=2), instance=IntType) -def __mod__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("imul"), instance=IntType) -def __mul__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ine"), instance=IntType) -def __ne__(self: int, other: int) -> bool: - ... - - -@extension.func(IntOpCompiler("ineg"), instance=IntType) -def __neg__(self: int) -> int: - ... - - -@extension.func(IntOpCompiler("ior"), instance=IntType) -def __or__(self: int, other: int) -> int: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __pos__(self: int) -> int: - ... - - -@extension.func(NotImplementedCompiler(), instance=IntType) # TODO -def __pow__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("iadd")), instance=IntType) -def __radd__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("iand")), instance=IntType) -def __rand__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("idivmod_s", num_params=2)), instance=IntType) -def __rdivmod__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("idiv_s", num_params=2)), instance=IntType) -def __rfloordiv__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("ishl", num_params=2)), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rlshift__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("imod_s", num_params=2)), instance=IntType) -def __rmod__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("imul")), instance=IntType) -def __rmul__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(IntOpCompiler("ior")), instance=IntType) -def __ror__(self: int, other: int) -> int: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __round__(self: int) -> int: - ... - - -@extension.func(Reversed(NotImplementedCompiler()), instance=IntType) # TODO -def __rpow__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("ishr", num_params=2)), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rrshift__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("ishr", num_params=2)), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rshift__(self: int, other: int) -> int: - ... - - -@extension.func( - Reversed(IntOpCompiler("isub")), instance=IntType -) # TODO: broken (RHS is unsigned) -def __rsub__(self: int, other: int) -> int: - ... - - -@extension.func(Reversed(TruedivCompiler()), instance=IntType) -def __rtruediv__(self: int, other: int) -> float: - ... - - -@extension.func(Reversed(IntOpCompiler("ixor")), instance=IntType) -def __rxor__(self: int, other: int) -> int: - ... - - -@extension.func(NotImplementedCompiler(), instance=IntType) # TODO -def __str__(self: int) -> str: - ... - - -@extension.func(IntOpCompiler("isub"), instance=IntType) -def __sub__(self: int, other: int) -> int: - ... - - -@extension.func(TruedivCompiler(), instance=IntType) -def __truediv__(self: int, other: int) -> float: - ... - - -@extension.func(IdOpCompiler(), instance=IntType) -def __trunc__(self: int, other: int) -> int: - ... - - -@extension.func(IntOpCompiler("ixor"), instance=IntType) -def __xor__(self: int, other: int) -> int: - ... diff --git a/guppy/statement.py b/guppy/statement.py deleted file mode 100644 index 502f0207..00000000 --- a/guppy/statement.py +++ /dev/null @@ -1,177 +0,0 @@ -import ast -from typing import Sequence - -from guppy.ast_util import AstVisitor, AstNode, set_location_from -from guppy.cfg.bb import BB, NestedFunctionDef -from guppy.compiler_base import ( - CompilerBase, - DFContainer, - Variable, - return_var, - Globals, -) -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.expression import ExpressionCompiler -from guppy.guppy_types import TupleType, TypeRow, GuppyType -from guppy.hugr.hugr import OutPortV, Hugr - - -class StatementCompiler(CompilerBase, AstVisitor[None]): - """A compiler for non-control-flow statements occurring in a basic block. - - Control-flow statements like loops or if-statements are not handled by this - compiler. They should be turned into a CFG made up of multiple simple basic blocks. - """ - - expr_compiler: ExpressionCompiler - - bb: BB - dfg: DFContainer - return_tys: list[GuppyType] - - def __init__(self, graph: Hugr, globals: Globals): - super().__init__(graph, globals) - self.expr_compiler = ExpressionCompiler(graph, globals) - - def compile_stmts( - self, - stmts: Sequence[ast.stmt], - bb: BB, - dfg: DFContainer, - return_tys: list[GuppyType], - ) -> DFContainer: - """Compiles a list of basic statements into a dataflow node. - - Note that the `dfg` is mutated in-place. After compilation, the DFG will also - contain all variables that are assigned in the given list of statements. - """ - self.bb = bb - self.dfg = dfg - self.return_tys = return_tys - for s in stmts: - self.visit(s) - return self.dfg - - def visit_Assign(self, node: ast.Assign) -> None: - if len(node.targets) > 1: - # This is the case for assignments like `a = b = 1` - raise GuppyError("Multi assignment not supported", node) - target = node.targets[0] - row = self.expr_compiler.compile_row(node.value, self.dfg) - if len(row) == 0: - # In Python it's fine to assign a void return with the variable being bound - # to `None` afterward. At the moment, we don't have a `None` type in Guppy, - # so we raise an error for now. - # TODO: Think about this. Maybe we should uniformly treat `None` as the - # empty tuple? - raise GuppyError("Cannot unpack empty row") - assert len(row) > 0 - - # Helper function to unpack the row based on the LHS pattern - def unpack(pattern: AstNode, ports: list[OutPortV]) -> None: - # Easiest case is if the LHS pattern is a single variable. Note we - # implicitly pack the row into a tuple if it has more than one element. - # I.e. `x = 1, 2` works just like `x = (1, 2)`. - if isinstance(pattern, ast.Name): - port = ( - ports[0] - if len(ports) == 1 - else self.graph.add_make_tuple( - inputs=ports, parent=self.dfg.node - ).out_port(0) - ) - # Check if we override an unused linear variable - x = pattern.id - if x in self.dfg: - var = self.dfg[x] - if var.ty.linear and var.used is None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` " - "is not used", - var.defined_at, - ) - self.dfg[x] = Variable(x, port, node) - # The only other thing we support right now are tuples - elif isinstance(pattern, ast.Tuple): - if len(ports) == 1 and isinstance(ports[0].ty, TupleType): - ports = list( - self.graph.add_unpack_tuple( - input_tuple=ports[0], parent=self.dfg.node - ).out_ports - ) - n, m = len(pattern.elts), len(ports) - if n != m: - raise GuppyTypeError( - f"{'Too many' if n < m else 'Not enough'} " - f"values to unpack (expected {n}, got {m})", - node, - ) - for pat, port in zip(pattern.elts, ports): - unpack(pat, [port]) - # TODO: Python also supports assignments like `[a, b] = [1, 2]` or - # `a, *b = ...`. The former would require some runtime checks but - # the latter should be easier to do (unpack and repack the rest). - else: - raise GuppyError("Assignment pattern not supported", pattern) - - unpack(target, row) - - def visit_AnnAssign(self, node: ast.AnnAssign) -> None: - # TODO: Figure out what to do with type annotations - raise NotImplementedError() - - def visit_AugAssign(self, node: ast.AugAssign) -> None: - bin_op = ast.BinOp(left=node.target, op=node.op, right=node.value) - set_location_from(bin_op, node) - assign = ast.Assign(targets=[node.target], value=bin_op) - set_location_from(assign, node) - self.visit_Assign(assign) - - def visit_Expr(self, node: ast.Expr) -> None: - self.expr_compiler.compile_row(node.value, self.dfg) - - def visit_Return(self, node: ast.Return) -> None: - if node.value is None: - row = [] - else: - port = self.expr_compiler.compile(node.value, self.dfg) - # Top-level tuples are unpacked, i.e. turned into a row - if isinstance(port.ty, TupleType): - unpack = self.graph.add_unpack_tuple(port, self.dfg.node) - row = [unpack.out_port(i) for i in range(len(port.ty.element_types))] - else: - row = [port] - - tys = [p.ty for p in row] - if tys != self.return_tys: - raise GuppyTypeError( - f"Return type mismatch: expected `{TypeRow(self.return_tys)}`, " - f"got `{TypeRow(tys)}`", - node.value, - ) - - # We turn returns into assignments of dummy variables, i.e. the statement - # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`. - for i, port in enumerate(row): - name = return_var(i) - self.dfg[name] = Variable(name, port, node) - - def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: - from guppy.function import FunctionDefCompiler - - func = FunctionDefCompiler(self.graph, self.globals).compile_local( - node, self.dfg, self.bb - ) - self.dfg[node.name] = Variable(node.name, func.port, node) - - def visit_If(self, node: ast.If) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_While(self, node: ast.While) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_Break(self, node: ast.Break) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_Continue(self, node: ast.Continue) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") From 18222ed38d18f5c3e61af5d5704a1d0b6651b347 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:01:10 +0000 Subject: [PATCH 11/77] Fix tests --- tests/error/linear_errors/branch_use.py | 10 ++-- tests/error/linear_errors/break_unused.py | 10 ++-- tests/error/linear_errors/continue_unused.py | 10 ++-- tests/error/linear_errors/copy.py | 6 +-- tests/error/linear_errors/if_both_unused.py | 8 +-- .../linear_errors/if_both_unused_reassign.py | 8 +-- tests/error/linear_errors/reassign_unused.py | 8 +-- .../linear_errors/reassign_unused_tuple.py | 8 +-- tests/error/linear_errors/unused.py | 6 +-- .../error/linear_errors/unused_same_block.py | 6 +-- .../nested_errors/different_types_if.err | 2 +- tests/error/nested_errors/linear_capture.py | 6 +-- tests/error/type_errors/and_not_bool_left.py | 6 +-- tests/error/type_errors/and_not_bool_right.py | 6 +-- tests/error/type_errors/fun_ty_mismatch_1.err | 2 +- tests/error/type_errors/fun_ty_mismatch_2.err | 2 +- tests/error/type_errors/if_expr_not_bool.py | 6 +-- tests/error/type_errors/if_not_bool.py | 6 +-- tests/error/type_errors/not_not_bool.py | 6 +-- tests/error/type_errors/or_not_bool_left.py | 6 +-- tests/error/type_errors/or_not_bool_right.py | 6 +-- tests/error/type_errors/return_mismatch.err | 2 +- tests/error/type_errors/while_not_bool.py | 6 +-- tests/error/util.py | 24 +++++---- tests/hugr/test_dummy_nodes.py | 18 +++---- tests/hugr/test_ports.py | 2 +- tests/integration/test_arithmetic.py | 4 +- tests/integration/test_basic.py | 9 ++-- tests/integration/test_call.py | 9 ++-- tests/integration/test_functional.py | 2 +- tests/integration/test_higher_order.py | 3 +- tests/integration/test_if.py | 3 +- tests/integration/test_linear.py | 49 ++++++++++--------- tests/integration/test_nested.py | 2 +- tests/integration/test_programs.py | 5 +- tests/integration/test_unused.py | 2 +- tests/integration/test_while.py | 2 +- 37 files changed, 143 insertions(+), 133 deletions(-) diff --git a/tests/error/linear_errors/branch_use.py b/tests/error/linear_errors/branch_use.py index 6ad6b73b..9ba256c1 100644 --- a/tests/error/linear_errors/branch_use.py +++ b/tests/error/linear_errors/branch_use.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,12 +8,12 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... -@guppy(module) +@guppy.declare(module) def measure(q: Qubit) -> bool: ... @@ -26,4 +26,4 @@ def foo(b: bool) -> bool: return False -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/break_unused.py b/tests/error/linear_errors/break_unused.py index 81e18bdf..b66d2ba4 100644 --- a/tests/error/linear_errors/break_unused.py +++ b/tests/error/linear_errors/break_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,12 +8,12 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... -@guppy(module) +@guppy.declare(module) def measure() -> bool: ... @@ -30,4 +30,4 @@ def foo(i: int) -> bool: return b -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/continue_unused.py b/tests/error/linear_errors/continue_unused.py index 57e12b53..4e6194f8 100644 --- a/tests/error/linear_errors/continue_unused.py +++ b/tests/error/linear_errors/continue_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,12 +8,12 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... -@guppy(module) +@guppy.declare(module) def measure() -> bool: ... @@ -30,4 +30,4 @@ def foo(i: int) -> bool: return b -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/copy.py b/tests/error/linear_errors/copy.py index c4d13986..1f726112 100644 --- a/tests/error/linear_errors/copy.py +++ b/tests/error/linear_errors/copy.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -13,4 +13,4 @@ def foo(q: Qubit) -> tuple[Qubit, Qubit]: return q, q -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/if_both_unused.py b/tests/error/linear_errors/if_both_unused.py index 459a3f5d..4ce0fe88 100644 --- a/tests/error/linear_errors/if_both_unused.py +++ b/tests/error/linear_errors/if_both_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -22,4 +22,4 @@ def foo(b: bool) -> int: return 42 -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/if_both_unused_reassign.py b/tests/error/linear_errors/if_both_unused_reassign.py index 3858e087..dcd83e66 100644 --- a/tests/error/linear_errors/if_both_unused_reassign.py +++ b/tests/error/linear_errors/if_both_unused_reassign.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -23,4 +23,4 @@ def foo(b: bool) -> Qubit: return q -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/reassign_unused.py b/tests/error/linear_errors/reassign_unused.py index 402ff729..7a936455 100644 --- a/tests/error/linear_errors/reassign_unused.py +++ b/tests/error/linear_errors/reassign_unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -19,4 +19,4 @@ def foo(q: Qubit) -> Qubit: return q -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/reassign_unused_tuple.py b/tests/error/linear_errors/reassign_unused_tuple.py index a888a67f..162e4163 100644 --- a/tests/error/linear_errors/reassign_unused_tuple.py +++ b/tests/error/linear_errors/reassign_unused_tuple.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -8,7 +8,7 @@ module.load(quantum) -@guppy(module) +@guppy.declare(module) def new_qubit() -> Qubit: ... @@ -19,4 +19,4 @@ def foo(q: Qubit) -> tuple[Qubit, Qubit]: return q, r -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/unused.py b/tests/error/linear_errors/unused.py index f45f2141..a22b8f16 100644 --- a/tests/error/linear_errors/unused.py +++ b/tests/error/linear_errors/unused.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -14,4 +14,4 @@ def foo(q: Qubit) -> int: return 42 -module.compile(True) +module.compile() diff --git a/tests/error/linear_errors/unused_same_block.py b/tests/error/linear_errors/unused_same_block.py index ccf8ca0b..55031f6b 100644 --- a/tests/error/linear_errors/unused_same_block.py +++ b/tests/error/linear_errors/unused_same_block.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -15,4 +15,4 @@ def foo(q: Qubit) -> int: return x -module.compile(True) +module.compile() diff --git a/tests/error/nested_errors/different_types_if.err b/tests/error/nested_errors/different_types_if.err index 0426774e..da397a08 100644 --- a/tests/error/nested_errors/different_types_if.err +++ b/tests/error/nested_errors/different_types_if.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:13 12: 13: return bar() ^^^ -GuppyError: Variable `bar` can refer to different types: `None -> int` (at 7:8) vs `None -> bool` (at 10:8) +GuppyError: Variable `bar` can refer to different types: `() -> int` (at 7:8) vs `() -> bool` (at 10:8) diff --git a/tests/error/nested_errors/linear_capture.py b/tests/error/nested_errors/linear_capture.py index 1545c31b..2c89d1fc 100644 --- a/tests/error/nested_errors/linear_capture.py +++ b/tests/error/nested_errors/linear_capture.py @@ -1,6 +1,6 @@ import guppy.prelude.quantum as quantum - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.hugr.tys import Qubit @@ -16,4 +16,4 @@ def bar() -> Qubit: return q -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/and_not_bool_left.py b/tests/error/type_errors/and_not_bool_left.py index 92647c8e..30f3b338 100644 --- a/tests/error/type_errors/and_not_bool_left.py +++ b/tests/error/type_errors/and_not_bool_left.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool, y: bool) -> bool: return x and y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/and_not_bool_right.py b/tests/error/type_errors/and_not_bool_right.py index 23637899..deb16df7 100644 --- a/tests/error/type_errors/and_not_bool_right.py +++ b/tests/error/type_errors/and_not_bool_right.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: bool, y: NonBool) -> bool: return x and y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/fun_ty_mismatch_1.err b/tests/error/type_errors/fun_ty_mismatch_1.err index 3e4cb8ef..a21337a1 100644 --- a/tests/error/type_errors/fun_ty_mismatch_1.err +++ b/tests/error/type_errors/fun_ty_mismatch_1.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:11 10: 11: return bar ^^^ -GuppyTypeError: Return type mismatch: expected `int -> int`, got `int -> bool` +GuppyTypeError: Expected return value of type `int -> int`, got `int -> bool` diff --git a/tests/error/type_errors/fun_ty_mismatch_2.err b/tests/error/type_errors/fun_ty_mismatch_2.err index 1f578dbd..ee3f83f4 100644 --- a/tests/error/type_errors/fun_ty_mismatch_2.err +++ b/tests/error/type_errors/fun_ty_mismatch_2.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:11 10: 11: return bar(foo) ^^^ -GuppyTypeError: Expected argument of type `None -> int`, got `int -> int` +GuppyTypeError: Expected argument of type `() -> int`, got `int -> int` diff --git a/tests/error/type_errors/if_expr_not_bool.py b/tests/error/type_errors/if_expr_not_bool.py index da546a40..41648464 100644 --- a/tests/error/type_errors/if_expr_not_bool.py +++ b/tests/error/type_errors/if_expr_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool) -> int: return 1 if x else 0 -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/if_not_bool.py b/tests/error/type_errors/if_not_bool.py index 3233d405..a51125fd 100644 --- a/tests/error/type_errors/if_not_bool.py +++ b/tests/error/type_errors/if_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -14,4 +14,4 @@ def foo(x: NonBool) -> int: return 1 -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/not_not_bool.py b/tests/error/type_errors/not_not_bool.py index f8eea0bd..03e9e083 100644 --- a/tests/error/type_errors/not_not_bool.py +++ b/tests/error/type_errors/not_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool) -> bool: return not x -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/or_not_bool_left.py b/tests/error/type_errors/or_not_bool_left.py index b7a9c2c6..9236e3b0 100644 --- a/tests/error/type_errors/or_not_bool_left.py +++ b/tests/error/type_errors/or_not_bool_left.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: NonBool, y: bool) -> bool: return x or y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/or_not_bool_right.py b/tests/error/type_errors/or_not_bool_right.py index 407f46b3..ebd7c092 100644 --- a/tests/error/type_errors/or_not_bool_right.py +++ b/tests/error/type_errors/or_not_bool_right.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -12,4 +12,4 @@ def foo(x: bool, y: NonBool) -> bool: return x or y -module.compile(True) +module.compile() diff --git a/tests/error/type_errors/return_mismatch.err b/tests/error/type_errors/return_mismatch.err index d7fa70cf..a59582be 100644 --- a/tests/error/type_errors/return_mismatch.err +++ b/tests/error/type_errors/return_mismatch.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:6 5: def foo() -> bool: 6: return 42 ^^ -GuppyTypeError: Return type mismatch: expected `bool`, got `int` +GuppyTypeError: Expected return value of type `bool`, got `int` diff --git a/tests/error/type_errors/while_not_bool.py b/tests/error/type_errors/while_not_bool.py index e12ea59c..2fe00a00 100644 --- a/tests/error/type_errors/while_not_bool.py +++ b/tests/error/type_errors/while_not_bool.py @@ -1,6 +1,6 @@ import tests.error.util - -from guppy.compiler import GuppyModule, guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.error.util import NonBool module = GuppyModule("test") @@ -14,4 +14,4 @@ def foo(x: NonBool) -> int: return 0 -module.compile(True) +module.compile() diff --git a/tests/error/util.py b/tests/error/util.py index 8fc76028..a3e39191 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -1,21 +1,22 @@ import importlib.util import pathlib -from typing import Callable, Optional, Any, TypeVar - import pytest -from guppy.compiler import GuppyModule -from guppy.extension import GuppyExtension +from typing import Callable, Optional, Any + from guppy.hugr import tys -from guppy.hugr.hugr import Hugr from guppy.hugr.tys import TypeBound +from guppy.module import GuppyModule +from guppy.hugr.hugr import Hugr + +import guppy.decorator as decorator def guppy(f: Callable[..., Any]) -> Optional[Hugr]: """ Decorator to compile functions outside of modules for testing. """ module = GuppyModule("module") - module.register_func(f) - return module.compile(exit_on_error=True) + module.register_func_def(f) + return module.compile() def run_error_test(file, capsys): @@ -35,5 +36,10 @@ def run_error_test(file, capsys): assert err == exp_err -ext = GuppyExtension("test", []) -NonBool = ext.new_type("NonBool", tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)) +util = GuppyModule("test") + + +@decorator.guppy.type(util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)) +class NonBool: + pass + diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index 7125c500..f42ba4b1 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -1,19 +1,15 @@ -import pytest - -from guppy.error import UndefinedPort, InternalGuppyError -from guppy.guppy_types import FunctionType +from guppy.guppy_types import FunctionType, BoolType, TupleType from guppy.hugr import ops from guppy.hugr.hugr import Hugr -from guppy.prelude.builtin import BoolType, IntType def test_single_dummy(): g = Hugr() - defn = g.add_def(FunctionType([IntType()], [IntType()]), g.root, "test") + defn = g.add_def(FunctionType([BoolType()], [BoolType()]), g.root, "test") dfg = g.add_dfg(defn) - inp = g.add_input([IntType()], dfg).out_port(0) + inp = g.add_input([BoolType()], dfg).out_port(0) dummy = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[IntType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg ) g.add_output([dummy.out_port(0)], parent=dfg) @@ -24,11 +20,11 @@ def test_single_dummy(): def test_unique_names(): g = Hugr() - defn = g.add_def(FunctionType([IntType()], [IntType(), BoolType]), g.root, "test") + defn = g.add_def(FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test") dfg = g.add_dfg(defn) - inp = g.add_input([IntType()], dfg).out_port(0) + inp = g.add_input([BoolType()], dfg).out_port(0) dummy1 = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[IntType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg ) dummy2 = g.add_node( ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py index d981840e..0e1d74e1 100644 --- a/tests/hugr/test_ports.py +++ b/tests/hugr/test_ports.py @@ -1,7 +1,7 @@ import pytest from guppy.error import UndefinedPort, InternalGuppyError -from guppy.prelude.builtin import BoolType +from guppy.guppy_types import BoolType def test_undefined_port(): diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 7ed1c6f1..c4112fd1 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -1,4 +1,4 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy def test_arith_basic(validate): @@ -17,7 +17,7 @@ def const() -> float: validate(const) -def test_ann_assign(validate): +def test_aug_assign(validate): @guppy def add(x: int) -> int: x += 1 diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index d29fdb7f..cde2801a 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -1,5 +1,6 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy from guppy.hugr import ops +from guppy.module import GuppyModule def test_id(validate): @@ -70,9 +71,11 @@ def func_name() -> None: def test_func_decl_name(): - @guppy + module = GuppyModule("test") + + @guppy.declare(module) def func_name() -> None: ... - [def_op] = [n.op for n in func_name.nodes() if isinstance(n.op, ops.FuncDecl)] + [def_op] = [n.op for n in module.compile().nodes() if isinstance(n.op, ops.FuncDecl)] assert def_op.name == "func_name" diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index 2fe6bf75..bad4d824 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -1,4 +1,5 @@ -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule def test_call(validate): @@ -12,7 +13,7 @@ def foo() -> int: def bar() -> int: return foo() - validate(module.compile(exit_on_error=True)) + validate(module.compile()) def test_call_back(validate): @@ -26,7 +27,7 @@ def foo(x: int) -> int: def bar(x: int) -> int: return x - validate(module.compile(exit_on_error=True)) + validate(module.compile()) def test_recursion(validate): @@ -48,7 +49,7 @@ def foo(x: int) -> int: def bar(x: int) -> int: return foo(x) - validate(module.compile(exit_on_error=True)) + validate(module.compile()) diff --git a/tests/integration/test_functional.py b/tests/integration/test_functional.py index 38aa9340..8c462d6b 100644 --- a/tests/integration/test_functional.py +++ b/tests/integration/test_functional.py @@ -1,6 +1,6 @@ import pytest -from guppy.compiler import guppy +from guppy.decorator import guppy from tests.integration.util import functional, _ diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py index b6c4d359..b406c932 100644 --- a/tests/integration/test_higher_order.py +++ b/tests/integration/test_higher_order.py @@ -1,6 +1,7 @@ from typing import Callable -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule def test_basic(validate): diff --git a/tests/integration/test_if.py b/tests/integration/test_if.py index 6ea6cafc..4f09b295 100644 --- a/tests/integration/test_if.py +++ b/tests/integration/test_if.py @@ -1,6 +1,7 @@ import pytest -from guppy.compiler import guppy +from guppy.decorator import guppy +from guppy.module import GuppyModule def test_if_no_else(validate): diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index 71ebcbb5..0d744e06 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -1,8 +1,9 @@ -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule from guppy.prelude.quantum import Qubit import guppy.prelude.quantum as quantum -from guppy.prelude.quantum import h, cx +from guppy.prelude.quantum import h, cx, measure def test_id(validate): @@ -13,7 +14,7 @@ def test_id(validate): def test(q: Qubit) -> Qubit: return q - validate(module.compile(True)) + validate(module.compile()) def test_assign(validate): @@ -26,7 +27,7 @@ def test(q: Qubit) -> Qubit: s = r return s - validate(module.compile(True)) + validate(module.compile()) def test_linear_return_order(validate): @@ -38,18 +39,18 @@ def test_linear_return_order(validate): def test(q: Qubit) -> tuple[Qubit, bool]: return measure(q) - validate(module.compile(True)) + validate(module.compile()) def test_interleave(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def f(q1: Qubit, q2: Qubit) -> tuple[Qubit, Qubit]: ... - @guppy(module) + @guppy.declare(module) def g(q1: Qubit, q2: Qubit) -> tuple[Qubit, Qubit]: ... @@ -60,14 +61,14 @@ def test(a: Qubit, b: Qubit, c: Qubit, d: Qubit) -> tuple[Qubit, Qubit, Qubit, Q b, c = g(b, c) return a, b, c, d - validate(module.compile(True)) + validate(module.compile()) def test_if(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def new() -> Qubit: ... @@ -80,14 +81,14 @@ def test(b: bool) -> Qubit: q = new() return q - validate(module.compile(True)) + validate(module.compile()) def test_if_return(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def new() -> Qubit: ... @@ -100,14 +101,14 @@ def test(b: bool) -> Qubit: q = new() return q - validate(module.compile(True)) + validate(module.compile()) def test_measure(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def measure(q: Qubit) -> bool: ... @@ -116,14 +117,14 @@ def test(q: Qubit, x: int) -> int: b = measure(q) return x - validate(module.compile(True)) + validate(module.compile()) def test_return_call(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def op(q: Qubit) -> Qubit: ... @@ -131,7 +132,7 @@ def op(q: Qubit) -> Qubit: def test(q: Qubit) -> Qubit: return op(q) - validate(module.compile(True)) + validate(module.compile()) def test_while(validate): @@ -145,7 +146,7 @@ def test(q: Qubit, i: int) -> Qubit: q = h(q) return q - validate(module.compile(True)) + validate(module.compile()) def test_while_break(validate): @@ -161,7 +162,7 @@ def test(q: Qubit, i: int) -> Qubit: break return q - validate(module.compile(True)) + validate(module.compile()) def test_while_continue(validate): @@ -177,18 +178,18 @@ def test(q: Qubit, i: int) -> Qubit: q = h(q) return q - validate(module.compile(True)) + validate(module.compile()) def test_while_reset(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def new_qubit() -> Qubit: ... - @guppy(module) + @guppy.declare(module) def measure() -> bool: ... @@ -208,15 +209,15 @@ def test_rus(validate): module = GuppyModule("test") module.load(quantum) - @guppy(module) + @guppy.declare(module) def measure(q: Qubit) -> bool: ... - @guppy(module) + @guppy.declare(module) def qalloc() -> Qubit: ... - @guppy(module) + @guppy.declare(module) def t(q: Qubit) -> Qubit: ... diff --git a/tests/integration/test_nested.py b/tests/integration/test_nested.py index caac4e0d..fe58523f 100644 --- a/tests/integration/test_nested.py +++ b/tests/integration/test_nested.py @@ -1,4 +1,4 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy def test_basic(validate): diff --git a/tests/integration/test_programs.py b/tests/integration/test_programs.py index d445a10e..01c2a6fd 100644 --- a/tests/integration/test_programs.py +++ b/tests/integration/test_programs.py @@ -1,4 +1,5 @@ -from guppy.compiler import guppy, GuppyModule +from guppy.decorator import guppy +from guppy.module import GuppyModule from tests.integration.util import functional, _ @@ -53,4 +54,4 @@ def is_odd(x: int) -> bool: return False return is_even(x - 1) - validate(module.compile(exit_on_error=True)) + validate(module.compile()) diff --git a/tests/integration/test_unused.py b/tests/integration/test_unused.py index 0ad04c84..772e6b60 100644 --- a/tests/integration/test_unused.py +++ b/tests/integration/test_unused.py @@ -1,6 +1,6 @@ import pytest -from guppy.compiler import guppy +from guppy.decorator import guppy """ All sorts of weird stuff is allowed when variables are not used. """ diff --git a/tests/integration/test_while.py b/tests/integration/test_while.py index 1af3939f..28d20dc0 100644 --- a/tests/integration/test_while.py +++ b/tests/integration/test_while.py @@ -1,4 +1,4 @@ -from guppy.compiler import guppy +from guppy.decorator import guppy def test_infinite_loop(validate): From 6db0327538353adf580b81380b9265acefbe7a75 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:03:43 +0000 Subject: [PATCH 12/77] Remove unused imports --- guppy/cfg/bb.py | 2 +- guppy/checker/cfg_checker.py | 3 +-- guppy/checker/core.py | 3 +-- guppy/checker/expr_checker.py | 5 +---- guppy/checker/func_checker.py | 9 +++------ guppy/checker/stmt_checker.py | 4 ++-- guppy/compiler/core.py | 2 +- guppy/compiler/func_compiler.py | 7 ++----- guppy/compiler/stmt_compiler.py | 4 ++-- guppy/custom.py | 2 +- guppy/decorator.py | 2 -- guppy/guppy_types.py | 2 +- guppy/hugr/hugr.py | 4 ++-- guppy/module.py | 6 +++--- 14 files changed, 21 insertions(+), 34 deletions(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index f42df90c..e3543204 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -1,7 +1,7 @@ import ast from abc import ABC from dataclasses import dataclass, field -from typing import Optional, TYPE_CHECKING, Union, Sequence +from typing import Optional, TYPE_CHECKING, Union from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index a8dc25e5..52f94d5a 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -1,10 +1,9 @@ import collections -from copy import copy from dataclasses import dataclass from typing import Sequence from guppy.ast_util import line_col -from guppy.cfg.bb import BB, BBStatement +from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG from guppy.checker.core import Globals, Context diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 461d150c..bcbe47f3 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,12 +1,11 @@ import ast from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import NamedTuple, Optional, Union from guppy.ast_util import AstNode from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType, NoneType, \ BoolType -from guppy.nodes import GlobalCall @dataclass diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 9b2e1ecf..f62fe5a7 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -1,13 +1,10 @@ import ast -from ast import NodeTransformer from typing import Optional, Union, NoReturn, Any -from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type, \ - get_type_opt +from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError from guppy.guppy_types import GuppyType, TupleType, FunctionType, BoolType -from guppy.hugr import val from guppy.nodes import LocalName, GlobalName, LocalCall # Mapping from unary AST op to dunder method and display name diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 87cf9ee6..13109781 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -1,6 +1,5 @@ import ast -from dataclasses import dataclass, field -from typing import Mapping +from dataclasses import dataclass from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc from guppy.cfg.bb import BB @@ -9,10 +8,8 @@ from guppy.checker.cfg_checker import check_cfg, CheckedCFG from guppy.checker.expr_checker import synthesize_call, check_call from guppy.error import GuppyError -from guppy.guppy_types import FunctionType, type_from_ast, NoneType, GuppyType, \ - type_to_row -from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, LocalName, \ - NestedFunctionDef +from guppy.guppy_types import FunctionType, type_from_ast, NoneType, GuppyType +from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef @dataclass diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 8f772b2a..96f891cb 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -1,7 +1,7 @@ import ast from typing import Sequence -from guppy.ast_util import set_location_from, with_loc, AstVisitor +from guppy.ast_util import with_loc, AstVisitor from guppy.cfg.bb import BB, BBStatement from guppy.checker.core import Variable, Context from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker @@ -74,7 +74,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.stmt: def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: if node.value is None: - raise GuppyError(f"Variable declaration is not supported. Assignment is required", node) + raise GuppyError("Variable declaration is not supported. Assignment is required", node) ty = type_from_ast(node.annotation, self.ctx.globals) node.value = self._check_expr(node.value, ty) self._check_assign(node.target, ty, node) diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 1d0269ed..9d93c54a 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Iterator, NamedTuple +from typing import Optional, Iterator from guppy.ast_util import AstNode from guppy.checker.core import Variable, CallableVariable diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index f0f62cb6..9a6e30ca 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -1,15 +1,12 @@ -import ast -from dataclasses import dataclass, field -from typing import Mapping, Sequence +from dataclasses import dataclass from guppy.ast_util import AstNode -from guppy.checker.core import Variable from guppy.checker.func_checker import CheckedFunction, DefinedFunction from guppy.compiler.cfg_compiler import compile_cfg from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer, \ PortVariable from guppy.guppy_types import type_to_row, FunctionType -from guppy.hugr.hugr import Hugr, DFContainingNode, OutPortV, DFContainingVNode +from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index 9ed45c65..a4269c0e 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -1,13 +1,13 @@ import ast from typing import Sequence -from guppy.ast_util import AstVisitor, get_type +from guppy.ast_util import AstVisitor from guppy.checker.cfg_checker import CheckedBB from guppy.compiler.core import CompilerBase, DFContainer, CompiledGlobals, \ PortVariable, return_var from guppy.compiler.expr_compiler import ExprCompiler from guppy.error import InternalGuppyError -from guppy.guppy_types import GuppyType, NoneType, TupleType +from guppy.guppy_types import TupleType from guppy.hugr.hugr import Hugr, OutPortV from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/custom.py b/guppy/custom.py index d267ae07..c888746f 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -3,7 +3,7 @@ from typing import Optional from guppy.ast_util import AstNode, get_type, with_type, with_loc -from guppy.checker.core import CallableVariable, Context, Globals +from guppy.checker.core import Context, Globals from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals diff --git a/guppy/decorator.py b/guppy/decorator.py index b61c06ac..894a94fa 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -1,10 +1,8 @@ -import ast import functools from dataclasses import dataclass from typing import Optional, Union, Callable, Any from guppy.ast_util import AstNode, has_empty_body -from guppy.checker.func_checker import check_signature from guppy.custom import CustomFunction, OpCompiler, DefaultCallChecker, \ CustomCallCompiler, CustomCallChecker, DefaultCallCompiler from guppy.error import GuppyError, pretty_errors diff --git a/guppy/guppy_types.py b/guppy/guppy_types.py index 95c86f1c..4cfa330e 100644 --- a/guppy/guppy_types.py +++ b/guppy/guppy_types.py @@ -1,7 +1,7 @@ import ast from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING, Union +from typing import Optional, Sequence, TYPE_CHECKING import guppy.hugr.tys as tys from guppy.ast_util import AstNode, set_location_from diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 71632980..fa34bbef 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Optional, Iterator, Tuple, Any, Union +from typing import Optional, Iterator, Tuple, Any from dataclasses import field, dataclass import guppy.hugr.ops as ops @@ -14,7 +14,7 @@ FunctionType, SumType, type_to_row, row_to_type, ) -from guppy.hugr import val, tys +from guppy.hugr import val NodeIdx = int PortOffset = int diff --git a/guppy/module.py b/guppy/module.py index 79ddeecf..2b501291 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -8,13 +8,13 @@ from guppy.ast_util import annotate_location, AstNode from guppy.checker.core import Globals, qualified_name from guppy.checker.func_checker import DefinedFunction, check_global_func_def -from guppy.compiler.core import CompiledGlobals, CompiledFunction, DFContainer +from guppy.compiler.core import CompiledGlobals from guppy.compiler.func_compiler import compile_global_func_def, CompiledFunctionDef from guppy.custom import CustomFunction from guppy.declared import DeclaredFunction from guppy.error import GuppyError, pretty_errors -from guppy.guppy_types import GuppyType, type_to_row -from guppy.hugr.hugr import Hugr, Node, VNode, OutPortV +from guppy.guppy_types import GuppyType +from guppy.hugr.hugr import Hugr PyFunc = Callable[..., Any] From 8616f6b835ef12551212d5ba13e9f3af12d02acc Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:05:05 +0000 Subject: [PATCH 13/77] Run formatting --- guppy/ast_util.py | 4 +- guppy/cfg/cfg.py | 4 +- guppy/checker/cfg_checker.py | 44 +++++++++------ guppy/checker/core.py | 18 ++++-- guppy/checker/expr_checker.py | 42 ++++++++++---- guppy/checker/func_checker.py | 22 ++++++-- guppy/checker/stmt_checker.py | 17 ++++-- guppy/compiler/cfg_compiler.py | 30 ++++++---- guppy/compiler/core.py | 12 +++- guppy/compiler/func_compiler.py | 58 +++++++++++++++---- guppy/compiler/stmt_compiler.py | 15 +++-- guppy/custom.py | 35 +++++++++--- guppy/declared.py | 29 ++++++++-- guppy/decorator.py | 54 ++++++++++++++---- guppy/error.py | 1 - guppy/hugr/hugr.py | 14 ++++- guppy/module.py | 41 ++++++++++---- guppy/nodes.py | 25 +++++---- guppy/prelude/_internal.py | 42 ++++++++++---- guppy/prelude/builtins.py | 99 ++++++++++++++++++++++++--------- guppy/prelude/quantum.py | 6 +- 21 files changed, 455 insertions(+), 157 deletions(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 77603220..bfe6bae9 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -118,7 +118,9 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None: annotate_location(node, source, file, line_offset) -def annotate_location(node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True) -> None: +def annotate_location( + node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True +) -> None: setattr(node, "line_offset", line_offset) setattr(node, "file", file) setattr(node, "source", source) diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py index 6bd1e4b0..2648f696 100644 --- a/guppy/cfg/cfg.py +++ b/guppy/cfg/cfg.py @@ -25,7 +25,9 @@ class BaseCFG(Generic[T]): ass_before: Result[DefAssignmentDomain] maybe_ass_before: Result[MaybeAssignmentDomain] - def __init__(self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None): + def __init__( + self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None + ): self.bbs = bbs if entry_bb: self.entry_bb = entry_bb diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 52f94d5a..fdcfb54a 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -49,7 +49,9 @@ def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None: self.output_ty = output_ty -def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedCFG: +def check_cfg( + cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals +) -> CheckedCFG: """Type checks a control-flow graph. Annotates the basic blocks with input and output type signatures and removes @@ -61,18 +63,24 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) # We start by compiling the entry BB checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) - checked_cfg.entry_bb = check_bb(cfg.entry_bb, checked_cfg, inputs, return_ty, globals) + checked_cfg.entry_bb = check_bb( + cfg.entry_bb, checked_cfg, inputs, return_ty, globals + ) compiled = {cfg.entry_bb: checked_cfg.entry_bb} # Visit all control-flow edges in BFS order. We can't just do a normal loop over # all BBs since the input types for a BB are computed by checking a predecessor. # We do BFS instead of DFS to get a better error ordering. queue = collections.deque( - (checked_cfg.entry_bb, i, succ) for i, succ in enumerate(cfg.entry_bb.successors) + (checked_cfg.entry_bb, i, succ) + for i, succ in enumerate(cfg.entry_bb.successors) ) while len(queue) > 0: pred, num_output, bb = queue.popleft() - input_row = [Variable(v.name, v.ty, v.defined_at, None) for v in pred.sig.output_rows[num_output]] + input_row = [ + Variable(v.name, v.ty, v.defined_at, None) + for v in pred.sig.output_rows[num_output] + ] if bb in compiled: # If the BB was already compiled, we just have to check that the signatures @@ -81,9 +89,7 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) else: # Otherwise, check the BB and enqueue its successors checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) - queue += [ - (checked_bb, i, succ) for i, succ in enumerate(bb.successors) - ] + queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)] compiled[bb] = checked_bb # Link up BBs in the checked CFG @@ -94,11 +100,19 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} - checked_cfg.maybe_ass_before = {compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs} + checked_cfg.maybe_ass_before = { + compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs + } return checked_cfg -def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedBB: +def check_bb( + bb: BB, + checked_cfg: CheckedCFG, + inputs: VarRow, + return_ty: GuppyType, + globals: Globals, +) -> CheckedBB: cfg = bb.cfg # For the entry BB we have to separately check that all used variables are @@ -132,9 +146,7 @@ def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyTy f"Variable `{x}` is not defined on all control-flow paths.", use_bb.vars.used[x], ) - raise GuppyError( - f"Variable `{x}` is not defined", use_bb.vars.used[x] - ) + raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x]) # We have to check that used linear variables are not being outputted if x in ctx.locals: @@ -166,7 +178,9 @@ def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyTy ] # Also prepare the successor list so we can fill it in later - checked_bb = CheckedBB(bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs)) + checked_bb = CheckedBB( + bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) + ) checked_bb.successors = [None] * len(bb.successors) # type: ignore checked_bb.branch_pred = bb.branch_pred return checked_bb @@ -193,9 +207,7 @@ def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: v1, v2 = v2, v1 # We shouldn't mention temporary variables (starting with `%`) # in error messages: - ident = ( - "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" - ) + ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" raise GuppyError( f"{ident} can refer to different types: " f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", diff --git a/guppy/checker/core.py b/guppy/checker/core.py index bcbe47f3..77cde03e 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -4,8 +4,14 @@ from typing import NamedTuple, Optional, Union from guppy.ast_util import AstNode -from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType, NoneType, \ - BoolType +from guppy.guppy_types import ( + GuppyType, + FunctionType, + TupleType, + SumType, + NoneType, + BoolType, +) @dataclass @@ -25,11 +31,15 @@ class CallableVariable(ABC, Variable): ty: FunctionType @abstractmethod - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context") -> ast.expr: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" + ) -> ast.expr: """Checks the return type of a function call against a given type.""" @abstractmethod - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: "Context" + ) -> tuple[ast.expr, GuppyType]: """Synthesizes the return type of a function call.""" diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index f62fe5a7..a5297f57 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -52,7 +52,12 @@ def __init__(self, ctx: Context) -> None: self.ctx = ctx self._kind = "expression" - def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Optional[AstNode] = None) -> NoReturn: + def _fail( + self, + expected: GuppyType, + actual: Union[ast.expr, GuppyType], + loc: Optional[AstNode] = None, + ) -> NoReturn: """Raises a type error indicating that the type doesn't match.""" if not isinstance(actual, GuppyType): loc = loc or actual @@ -62,7 +67,9 @@ def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Op f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc ) - def check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + def check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: """Checks an expression against a type. Returns a new desugared expression with type annotations. @@ -108,7 +115,9 @@ def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: node, ty = self.visit(node) return with_type(ty, node), ty - def _check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + def _check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: """Checks an expression against a given type""" return ExprChecker(self.ctx).check(expr, ty, kind) @@ -218,7 +227,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: node.func, ty = self.synthesize(node.func) # First handle direct calls of user-defined functions and extension functions - if isinstance(node.func, GlobalName) and isinstance(node.func.value, CallableVariable): + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): return node.func.value.synthesize_call(node.args, node, self.ctx) # Otherwise, it must be a function as a higher-order value @@ -263,7 +274,9 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: ) -def synthesize_call(func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[list[ast.expr], GuppyType]: +def synthesize_call( + func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context +) -> tuple[list[ast.expr], GuppyType]: """Synthesizes the return type of a function call""" check_num_args(len(func_ty.args), len(args), node) for i, arg in enumerate(args): @@ -271,7 +284,13 @@ def synthesize_call(func_ty: FunctionType, args: list[ast.expr], node: AstNode, return args, func_ty.returns -def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> list[ast.expr]: +def check_call( + func_ty: FunctionType, + args: list[ast.expr], + ty: GuppyType, + node: AstNode, + ctx: Context, +) -> list[ast.expr]: """Checks the return type of a function call against a given type""" args, return_ty = synthesize_call(func_ty, args, node, ctx) if return_ty != ty: @@ -281,7 +300,9 @@ def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node: return args -def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr, GuppyType]: +def to_bool( + node: ast.expr, node_ty: GuppyType, ctx: Context +) -> tuple[ast.expr, GuppyType]: """Tries to turn a node into a bool""" if isinstance(node_ty, BoolType): return node, node_ty @@ -299,12 +320,14 @@ def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr, if not isinstance(return_ty, BoolType): raise GuppyTypeError( f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", - node + node, ) return call, return_ty -def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Optional[GuppyType]: +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals +) -> Optional[GuppyType]: """Turns a primitive Python value into a Guppy type. Returns `None` if the Python value cannot be represented in Guppy. @@ -316,4 +339,3 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Opti if isinstance(v, float): return globals.types["float"].build(node=node) return None - diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 13109781..facd10b6 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -15,20 +15,27 @@ @dataclass class DefinedFunction(CallableVariable): """A user-defined function""" + ty: FunctionType defined_at: ast.FunctionDef @staticmethod - def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DefinedFunction": + def from_ast( + func_def: ast.FunctionDef, name: str, globals: Globals + ) -> "DefinedFunction": ty = check_signature(func_def, globals) return DefinedFunction(name, ty, func_def, None) - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> GlobalCall: # Use default implementation from the expression checker args = check_call(self.ty, args, ty, node, ctx) return GlobalCall(func=self, args=args) - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty = synthesize_call(self.ty, args, node, ctx) return GlobalCall(func=self, args=args), ty @@ -57,7 +64,9 @@ def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFun return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) -def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) -> CheckedNestedFunctionDef: +def check_nested_func_def( + func_def: NestedFunctionDef, bb: BB, ctx: Context +) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" func_ty = check_signature(func_def, ctx.globals) assert func_ty.arg_names is not None @@ -102,7 +111,10 @@ def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) -> ) # Construct inputs for checking the body CFG - inputs = list(captured.values()) + [Variable(x, ty, func_def.args.args[i], None) for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args))] + inputs = list(captured.values()) + [ + Variable(x, ty, func_def.args.args[i], None) + for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) + ] globals = ctx.globals # Check if the body contains a recursive occurrence of the function name diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 96f891cb..5030c649 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -11,7 +11,6 @@ class StmtChecker(AstVisitor[BBStatement]): - ctx: Context bb: BB return_ty: GuppyType @@ -27,7 +26,9 @@ def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: return ExprSynthesizer(self.ctx).synthesize(node) - def _check_expr(self, node: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + def _check_expr( + self, node: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: return ExprChecker(self.ctx).check(node, ty, kind) def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: @@ -74,14 +75,18 @@ def visit_Assign(self, node: ast.Assign) -> ast.stmt: def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: if node.value is None: - raise GuppyError("Variable declaration is not supported. Assignment is required", node) + raise GuppyError( + "Variable declaration is not supported. Assignment is required", node + ) ty = type_from_ast(node.annotation, self.ctx.globals) node.value = self._check_expr(node.value, ty) self._check_assign(node.target, ty, node) return node def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: - bin_op = with_loc(node, ast.BinOp(left=node.target, op=node.op, right=node.value)) + bin_op = with_loc( + node, ast.BinOp(left=node.target, op=node.op, right=node.value) + ) assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) return self.visit_Assign(assign) @@ -103,7 +108,9 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: from guppy.checker.func_checker import check_nested_func_def func_def = check_nested_func_def(node, self.bb, self.ctx) - self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def, None) + self.ctx.locals[func_def.name] = Variable( + func_def.name, func_def.ty, func_def, None + ) return func_def def visit_If(self, node: ast.If) -> None: diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index bd31be7e..7446af8e 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -3,15 +3,22 @@ from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature from guppy.checker.core import Variable -from guppy.compiler.core import CompiledGlobals, is_return_var, DFContainer, return_var, \ - PortVariable +from guppy.compiler.core import ( + CompiledGlobals, + is_return_var, + DFContainer, + return_var, + PortVariable, +) from guppy.compiler.expr_compiler import ExprCompiler from guppy.compiler.stmt_compiler import StmtCompiler from guppy.guppy_types import TupleType, SumType, type_to_row from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV -def compile_cfg(cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals) -> None: +def compile_cfg( + cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals +) -> None: """Compiles a CFG to Hugr.""" insert_return_vars(cfg) @@ -23,7 +30,9 @@ def compile_cfg(cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlo graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) -def compile_bb(bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals) -> CFNode: +def compile_bb( + bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals +) -> CFNode: """Compiles a single basic block to Hugr.""" inputs = sort_vars(bb.sig.input_row) @@ -100,7 +109,10 @@ def insert_return_vars(cfg: CheckedCFG) -> None: `%ret0`, `%ret1`, etc. We update the exit BB signature to make sure they are correctly outputted. """ - return_vars = [Variable(return_var(i), ty, None, None) for i, ty in enumerate(type_to_row(cfg.output_ty))] + return_vars = [ + Variable(return_var(i), ty, None, None) + for i, ty in enumerate(type_to_row(cfg.output_ty)) + ] # Before patching, the exit BB shouldn't take any inputs assert len(cfg.exit_bb.sig.input_row) == 0 cfg.exit_bb.sig = Signature(return_vars, cfg.exit_bb.sig.output_rows) @@ -123,14 +135,13 @@ def choose_vars_for_pred( assert len(pred.ty.element_types) == len(output_vars) tuples = [ graph.add_make_tuple( - inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], parent=dfg.node + inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], + parent=dfg.node, ).out_port(0) for vs in output_vars ] tys = [t.ty for t in tuples] - conditional = graph.add_conditional( - cond_input=pred, inputs=tuples, parent=dfg.node - ) + conditional = graph.add_conditional(cond_input=pred, inputs=tuples, parent=dfg.node) for i, ty in enumerate(tys): case = graph.add_case(conditional) inp = graph.add_input(output_tys=tys, parent=case).out_port(i) @@ -158,4 +169,3 @@ def sort_vars(row: VarRow) -> list[Variable]: This determines the order in which they are outputted from a BB. """ return sorted(row, key=functools.cmp_to_key(compare_var)) - diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 9d93c54a..2e616e5a 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -17,7 +17,13 @@ class PortVariable(Variable): port: OutPortV - def __init__(self, name: str, port: OutPortV, defined_at: Optional[AstNode], used: Optional[AstNode] = None): + def __init__( + self, + name: str, + port: OutPortV, + defined_at: Optional[AstNode], + used: Optional[AstNode] = None, + ) -> None: super().__init__(name, port.ty, defined_at, used) object.__setattr__(self, "port", port) @@ -26,7 +32,9 @@ class CompiledVariable(ABC, Variable): """Abstract base class for compiled global module-level variables.""" @abstractmethod - def load(self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode) -> OutPortV: + def load( + self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode + ) -> OutPortV: """Loads the variable as a value into a local dataflow graph.""" diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index 9a6e30ca..ee032cc9 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -3,8 +3,12 @@ from guppy.ast_util import AstNode from guppy.checker.func_checker import CheckedFunction, DefinedFunction from guppy.compiler.cfg_compiler import compile_cfg -from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer, \ - PortVariable +from guppy.compiler.core import ( + CompiledFunction, + CompiledGlobals, + DFContainer, + PortVariable, +) from guppy.guppy_types import type_to_row, FunctionType from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef @@ -14,15 +18,29 @@ class CompiledFunctionDef(DefinedFunction, CompiledFunction): node: DFContainingVNode - def load(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) - def compile_call(self, args: list[OutPortV], dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> list[OutPortV]: + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] -def compile_global_func_def(func: CheckedFunction, def_node: DFContainingVNode, graph: Hugr, globals: CompiledGlobals) -> CompiledFunctionDef: +def compile_global_func_def( + func: CheckedFunction, + def_node: DFContainingVNode, + graph: Hugr, + globals: CompiledGlobals, +) -> CompiledFunctionDef: """Compiles a top-level function definition to Hugr.""" def_input = graph.add_input(parent=def_node) cfg_node = graph.add_cfg( @@ -31,12 +49,20 @@ def compile_global_func_def(func: CheckedFunction, def_node: DFContainingVNode, compile_cfg(func.cfg, graph, cfg_node, globals) # Add output node for the cfg - graph.add_output(inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], parent=def_node) + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) -def compile_local_func_def(func: CheckedNestedFunctionDef, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals) -> PortVariable: +def compile_local_func_def( + func: CheckedNestedFunctionDef, + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, +) -> PortVariable: """Compiles a local (nested) function definition to Hugr.""" assert func.ty.arg_names is not None @@ -59,24 +85,32 @@ def compile_local_func_def(func: CheckedNestedFunctionDef, dfg: DFContainer, gra # variable if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0) - partial = graph.add_partial(loaded, [def_input.out_port(i) for i in range(len(captured))], def_node) + partial = graph.add_partial( + loaded, [def_input.out_port(i) for i in range(len(captured))], def_node + ) input_ports.append(partial.out_port(0)) func.cfg.input_tys.append(func.ty) else: # Otherwise, we treat the function like a normal global variable - globals = globals | {func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node)} + globals = globals | { + func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node) + } # Compile the CFG cfg_node = graph.add_cfg(def_node, inputs=input_ports) compile_cfg(func.cfg, graph, cfg_node, globals) # Add output node for the cfg - graph.add_output(inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], parent=def_node) + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) # Finally, load the function into the local data-flow graph loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) if len(captured) > 0: - loaded = graph.add_partial(loaded, [dfg[v.name].port for v in captured], dfg.node).out_port(0) + loaded = graph.add_partial( + loaded, [dfg[v.name].port for v in captured], dfg.node + ).out_port(0) return PortVariable(func.name, loaded, func) - diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index a4269c0e..61b80dff 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -3,8 +3,13 @@ from guppy.ast_util import AstVisitor from guppy.checker.cfg_checker import CheckedBB -from guppy.compiler.core import CompilerBase, DFContainer, CompiledGlobals, \ - PortVariable, return_var +from guppy.compiler.core import ( + CompilerBase, + DFContainer, + CompiledGlobals, + PortVariable, + return_var, +) from guppy.compiler.expr_compiler import ExprCompiler from guppy.error import InternalGuppyError from guppy.guppy_types import TupleType @@ -86,6 +91,6 @@ def visit_Return(self, node: ast.Return) -> None: def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: from guppy.compiler.func_compiler import compile_local_func_def - self.dfg[node.name] = compile_local_func_def(node, self.dfg, self.graph, self.globals) - - + self.dfg[node.name] = compile_local_func_def( + node, self.dfg, self.graph, self.globals + ) diff --git a/guppy/custom.py b/guppy/custom.py index c888746f..0c03f2a8 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -8,8 +8,12 @@ from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.compiler.expr_compiler import ExprCompiler -from guppy.error import GuppyError, InternalGuppyError, UnknownFunctionType, \ - GuppyTypeError +from guppy.error import ( + GuppyError, + InternalGuppyError, + UnknownFunctionType, + GuppyTypeError, +) from guppy.guppy_types import GuppyType, FunctionType, type_to_row, TupleType from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode @@ -31,7 +35,15 @@ class CustomFunction(CompiledFunction): _ty: Optional[FunctionType] = None _defined: dict[Node, DFContainingVNode] = {} - def __init__(self, name: str, defined_at: Optional[ast.FunctionDef], compiler: "CustomCallCompiler", checker: "CustomCallChecker", higher_order_value: bool = True, ty: Optional[FunctionType] = None): + def __init__( + self, + name: str, + defined_at: Optional[ast.FunctionDef], + compiler: "CustomCallCompiler", + checker: "CustomCallChecker", + higher_order_value: bool = True, + ty: Optional[FunctionType] = None, + ): self.name = name self.defined_at = defined_at self.higher_order_value = higher_order_value @@ -73,11 +85,15 @@ def check_type(self, globals: Globals) -> None: if self.call_checker is None or self.higher_order_value: raise err - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> ast.expr: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> ast.expr: self.call_checker._setup(ctx, node, self) return with_type(ty, with_loc(node, self.call_checker.check(args, ty))) - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: "Context" + ) -> tuple[ast.expr, GuppyType]: self.call_checker._setup(ctx, node, self) new_node, ty = self.call_checker.synthesize(args) return with_type(ty, with_loc(node, new_node)), ty @@ -93,7 +109,9 @@ def compile_call( self.call_compiler._setup(dfg, graph, globals, node) return self.call_compiler.compile(args) - def load(self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + def load( + self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: """Loads the custom function as a value into a local dataflow graph. This will place a `FunctionDef` node into the Hugr module if one for this @@ -175,7 +193,9 @@ class CustomCallCompiler(ABC): globals: CompiledGlobals node: AstNode - def _setup(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> None: + def _setup( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> None: self.dfg = dfg self.graph = graph self.globals = globals @@ -226,6 +246,5 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: class NoopCompiler(CustomCallCompiler): - def compile(self, args: list[OutPortV]) -> list[OutPortV]: return args diff --git a/guppy/declared.py b/guppy/declared.py index 167f9abd..257c74ad 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -20,18 +20,26 @@ class DeclaredFunction(CompiledFunction): node: Optional[VNode] = None @staticmethod - def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DeclaredFunction": + def from_ast( + func_def: ast.FunctionDef, name: str, globals: Globals + ) -> "DeclaredFunction": ty = check_signature(func_def, globals) if not has_empty_body(func_def): - raise GuppyError("Body of function declaration must be empty", func_def.body[0]) + raise GuppyError( + "Body of function declaration must be empty", func_def.body[0] + ) return DeclaredFunction(name, ty, func_def, None) - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> GlobalCall: # Use default implementation from the expression checker args = check_call(self.ty, args, ty, node, ctx) return GlobalCall(func=self, args=args) - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty = synthesize_call(self.ty, args, node, ctx) return GlobalCall(func=self, args=args), ty @@ -39,11 +47,20 @@ def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) - def load(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: assert self.node is not None return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) - def compile_call(self, args: list[OutPortV], dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> list[OutPortV]: + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: assert self.node is not None call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/decorator.py b/guppy/decorator.py index 894a94fa..027e032f 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -3,8 +3,14 @@ from typing import Optional, Union, Callable, Any from guppy.ast_util import AstNode, has_empty_body -from guppy.custom import CustomFunction, OpCompiler, DefaultCallChecker, \ - CustomCallCompiler, CustomCallChecker, DefaultCallCompiler +from guppy.custom import ( + CustomFunction, + OpCompiler, + DefaultCallChecker, + CustomCallCompiler, + CustomCallChecker, + DefaultCallCompiler, +) from guppy.error import GuppyError, pretty_errors from guppy.guppy_types import GuppyType from guppy.hugr import tys, ops @@ -29,13 +35,16 @@ def set_module(self, module: GuppyModule) -> None: self._module = module @pretty_errors - def __call__(self, arg: Union[PyFunc, GuppyModule]) -> Union[Optional[Hugr], FuncDecorator]: + def __call__( + self, arg: Union[PyFunc, GuppyModule] + ) -> Union[Optional[Hugr], FuncDecorator]: """Decorator to annotate Python functions as Guppy code. Optionally, the `GuppyModule` in which the function should be placed can be passed to the decorator. """ if isinstance(arg, GuppyModule): + def dec(f: Callable[..., Any]) -> Callable[..., Any]: assert isinstance(arg, GuppyModule) arg.register_func_def(f) @@ -66,7 +75,13 @@ def dec(c: type) -> type: return dec @pretty_errors - def type(self, module: GuppyModule, hugr_ty: tys.SimpleType, name: str = "", linear: bool = False) -> ClassDecorator: + def type( + self, + module: GuppyModule, + hugr_ty: tys.SimpleType, + name: str = "", + linear: bool = False, + ) -> ClassDecorator: """Decorator to annotate a class definitions as Guppy types. Requires the static Hugr translation of the type. Additionally, the type can be @@ -111,7 +126,14 @@ def __str__(self) -> str: return dec @pretty_errors - def custom(self, module: GuppyModule, compiler: Optional[CustomCallCompiler] = None, checker: Optional[CustomCallChecker] = None, higher_order_value: bool = True, name: str = "") -> CustomFuncDecorator: + def custom( + self, + module: GuppyModule, + compiler: Optional[CustomCallCompiler] = None, + checker: Optional[CustomCallChecker] = None, + higher_order_value: bool = True, + name: str = "", + ) -> CustomFuncDecorator: """Decorator to add custom typing or compilation behaviour to function decls. Optionally, usage of the function as a higher-order value can be disabled. In @@ -124,22 +146,36 @@ def dec(f: PyFunc) -> CustomFunction: if not has_empty_body(func_ast): raise GuppyError( "Body of custom function declaration must be empty", - func_ast.body[0] + func_ast.body[0], ) call_checker = checker or DefaultCallChecker() - func = CustomFunction(name or func_ast.name, func_ast, compiler or DefaultCallCompiler(), call_checker, higher_order_value) + func = CustomFunction( + name or func_ast.name, + func_ast, + compiler or DefaultCallCompiler(), + call_checker, + higher_order_value, + ) call_checker.func = func module.register_custom_func(func) return func return dec - def hugr_op(self, module: GuppyModule, op: ops.OpType, checker: Optional[CustomCallChecker] = None, higher_order_value: bool = True, name: str = "") -> CustomFuncDecorator: + def hugr_op( + self, + module: GuppyModule, + op: ops.OpType, + checker: Optional[CustomCallChecker] = None, + higher_order_value: bool = True, + name: str = "", + ) -> CustomFuncDecorator: """Decorator to annotate function declarations as HUGR ops.""" return self.custom(module, OpCompiler(op), checker, higher_order_value, name) def declare(self, module: GuppyModule) -> FuncDecorator: """Decorator to declare functions""" + def dec(f: Callable[..., Any]) -> Callable[..., Any]: module.register_func_decl(f) @@ -154,6 +190,4 @@ def dummy(*args: Any, **kwargs: Any) -> Any: return dec - - guppy = _Guppy() diff --git a/guppy/error.py b/guppy/error.py index 81dfec17..77a08cea 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -176,4 +176,3 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return None return cast(FuncT, wrapped) - diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index fa34bbef..d9f916e6 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -12,7 +12,9 @@ GuppyType, TupleType, FunctionType, - SumType, type_to_row, row_to_type, + SumType, + type_to_row, + row_to_type, ) from guppy.hugr import val @@ -483,7 +485,11 @@ def add_call( """Adds a `Call` node to the graph.""" assert isinstance(def_port.ty, FunctionType) return self.add_node( - ops.Call(), None, list(type_to_row(def_port.ty.returns)), parent, args + [def_port] + ops.Call(), + None, + list(type_to_row(def_port.ty.returns)), + parent, + args + [def_port], ) def add_indirect_call( @@ -628,7 +634,9 @@ def remove_dummy_nodes(self) -> "Hugr": for n in list(self.nodes()): if isinstance(n, VNode) and isinstance(n.op, ops.DummyOp): name = n.op.name - fun_ty = FunctionType(list(n.in_port_types), row_to_type(n.out_port_types)) + fun_ty = FunctionType( + list(n.in_port_types), row_to_type(n.out_port_types) + ) if name in used_names: used_names[name] += 1 name = f"{name}${used_names[name]}" diff --git a/guppy/module.py b/guppy/module.py index 2b501291..a88a5dc5 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -61,6 +61,7 @@ def __init__(self, name: str, import_builtins: bool = True): # Import builtin module if import_builtins: import guppy.prelude.builtins as builtins + self.load(builtins) def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: @@ -72,7 +73,9 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: m.compile() # For now, we can only import custom functions - if any(not isinstance(v, CustomFunction) for v in m._compiled_globals.values()): + if any( + not isinstance(v, CustomFunction) for v in m._compiled_globals.values() + ): raise GuppyError( "Importing modules with defined functions is not supported yet" ) @@ -84,14 +87,18 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: if isinstance(val, GuppyModule): self.load(val) - def register_func_def(self, f: PyFunc, instance: Optional[type[GuppyType]] = None) -> None: + def register_func_def( + self, f: PyFunc, instance: Optional[type[GuppyType]] = None + ) -> None: """Registers a Python function definition as belonging to this Guppy module.""" self._check_not_yet_compiled() func_ast = parse_py_func(f) if self._instance_func_buffer is not None: self._instance_func_buffer[func_ast.name] = f else: - name = qualified_name(instance, func_ast.name) if instance else func_ast.name + name = ( + qualified_name(instance, func_ast.name) if instance else func_ast.name + ) self._check_name_available(name, func_ast) self._func_defs[name] = func_ast @@ -102,7 +109,9 @@ def register_func_decl(self, f: PyFunc) -> None: self._check_name_available(func_ast.name, func_ast) self._func_decls[func_ast.name] = func_ast - def register_custom_func(self, func: CustomFunction, instance: Optional[type[GuppyType]] = None) -> None: + def register_custom_func( + self, func: CustomFunction, instance: Optional[type[GuppyType]] = None + ) -> None: """Registers a custom function as belonging to this Guppy module.""" self._check_not_yet_compiled() if self._instance_func_buffer is not None: @@ -153,7 +162,10 @@ def compile(self) -> Optional[Hugr]: self._globals.values.update(defined_funcs) # Type check function definitions - checked = {x: check_global_func_def(f, self._imported_globals | self._globals) for x, f in defined_funcs.items()} + checked = { + x: check_global_func_def(f, self._imported_globals | self._globals) + for x, f in defined_funcs.items() + } # Add declared functions to the graph graph = Hugr(self.name) @@ -163,14 +175,23 @@ def compile(self) -> Optional[Hugr]: # Prepare `FunctionDef` nodes for all function definitions def_nodes = {x: graph.add_def(f.ty, module_node, x) for x, f in checked.items()} - self._compiled_globals |= self._custom_funcs | declared_funcs | { - x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) - for x, f in checked.items() - } + self._compiled_globals |= ( + self._custom_funcs + | declared_funcs + | { + x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) + for x, f in checked.items() + } + ) # Compile function definitions to Hugr for x, f in checked.items(): - compile_global_func_def(f, def_nodes[x], graph, self._imported_compiled_globals | self._compiled_globals) + compile_global_func_def( + f, + def_nodes[x], + graph, + self._imported_compiled_globals | self._compiled_globals, + ) self._compiled = True return graph diff --git a/guppy/nodes.py b/guppy/nodes.py index a84f9de0..56795f1d 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -14,9 +14,7 @@ class LocalName(ast.expr): id: str - _fields = ( - 'id', - ) + _fields = ("id",) class GlobalName(ast.expr): @@ -24,8 +22,8 @@ class GlobalName(ast.expr): value: "Variable" _fields = ( - 'id', - 'value', + "id", + "value", ) @@ -34,8 +32,8 @@ class LocalCall(ast.expr): args: list[ast.expr] _fields = ( - 'func', - 'args', + "func", + "args", ) @@ -46,8 +44,8 @@ class GlobalCall(ast.expr): # Later: Inferred type args _fields = ( - 'func', - 'args', + "func", + "args", ) @@ -66,7 +64,14 @@ class CheckedNestedFunctionDef(ast.FunctionDef): ty: FunctionType captured: Mapping[str, "Variable"] - def __init__(self, cfg: "CheckedCFG", ty: FunctionType, captured: Mapping[str, "Variable"], *args: Any, **kwargs: Any) -> None: + def __init__( + self, + cfg: "CheckedCFG", + ty: FunctionType, + captured: Mapping[str, "Variable"], + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.cfg = cfg self.ty = ty diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 7c410f7c..c7b257e6 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -6,8 +6,12 @@ from guppy.ast_util import with_type, AstNode, with_loc, get_type from guppy.checker.core import Context, CallableVariable from guppy.checker.expr_checker import ExprSynthesizer, check_num_args -from guppy.custom import CustomCallChecker, DefaultCallChecker, CustomFunction, \ - CustomCallCompiler +from guppy.custom import ( + CustomCallChecker, + DefaultCallChecker, + CustomFunction, + CustomCallCompiler, +) from guppy.error import GuppyTypeError from guppy.guppy_types import GuppyType, type_to_row, FunctionType, BoolType from guppy.hugr import ops, tys, val @@ -69,7 +73,9 @@ def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) -def int_op(op_name: str, ext: str = "arithmetic.int", num_params: int = 1) -> ops.OpType: +def int_op( + op_name: str, ext: str = "arithmetic.int", num_params: int = 1 +) -> ops.OpType: """Utility method to create Hugr integer arithmetic ops.""" return ops.CustomOp( extension=ext, @@ -92,8 +98,12 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: for i in range(len(args)): args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) if isinstance(ty, self.ctx.globals.types["int"]): - call = with_loc(self.node, GlobalCall(func=Int.__float__, args=[args[i]])) - args[i] = with_type(self.ctx.globals.types["float"].build(node=self.node), call) + call = with_loc( + self.node, GlobalCall(func=Int.__float__, args=[args[i]]) + ) + args[i] = with_type( + self.ctx.globals.types["float"].build(node=self.node), call + ) return super().synthesize(args) @@ -148,7 +158,9 @@ def __init__(self, dunder_name: str, num_args: int = 1): self.dunder_name = dunder_name self.num_args = num_args - def _get_func(self, args: list[ast.expr]) -> tuple[list[ast.expr], CallableVariable]: + def _get_func( + self, args: list[ast.expr] + ) -> tuple[list[ast.expr], CallableVariable]: check_num_args(self.num_args, len(args), self.node) fst, *rest = args fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) @@ -193,9 +205,15 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # Compile `truediv` using float arithmetic [left, right] = args - [left] = Int.__float__.compile_call([left], self.dfg, self.graph, self.globals, self.node) - [right] = Int.__float__.compile_call([right], self.dfg, self.graph, self.globals, self.node) - return Float.__truediv__.compile_call([left, right], self.dfg, self.graph, self.globals, self.node) + [left] = Int.__float__.compile_call( + [left], self.dfg, self.graph, self.globals, self.node + ) + [right] = Int.__float__.compile_call( + [right], self.dfg, self.graph, self.globals, self.node + ) + return Float.__truediv__.compile_call( + [left, right], self.dfg, self.graph, self.globals, self.node + ) class FloatBoolCompiler(CustomCallCompiler): @@ -205,7 +223,9 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: from .builtins import Float # We have: bool(x) = (x != 0.0) - zero_const = self.graph.add_constant(float_value(0.0), get_type(self.node), self.dfg.node) + zero_const = self.graph.add_constant( + float_value(0.0), get_type(self.node), self.dfg.node + ) zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) return Float.__ne__.compile_call( [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node @@ -260,4 +280,4 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: [mod] = Float.__mod__.compile_call( args, self.dfg, self.graph, self.globals, self.node ) - return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] \ No newline at end of file + return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index c684aac6..d83d4633 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -7,10 +7,23 @@ from guppy.guppy_types import BoolType from guppy.hugr import tys from guppy.module import GuppyModule -from guppy.prelude._internal import logic_op, int_op, hugr_int_type, hugr_float_type, \ - float_op, CoercingChecker, ReversingChecker, IntTruedivCompiler, FloatBoolCompiler, \ - FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, \ - NotImplementedCompiler, DunderChecker, CallableChecker +from guppy.prelude._internal import ( + logic_op, + int_op, + hugr_int_type, + hugr_float_type, + float_op, + CoercingChecker, + ReversingChecker, + IntTruedivCompiler, + FloatBoolCompiler, + FloatDivmodCompiler, + FloatFloordivCompiler, + FloatModCompiler, + NotImplementedCompiler, + DunderChecker, + CallableChecker, +) builtins = GuppyModule("builtins", import_builtins=False) @@ -18,7 +31,6 @@ @guppy.extend_type(builtins, BoolType) class Bool: - @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) def __and__(self: bool, other: bool) -> bool: ... @@ -38,7 +50,6 @@ def __or__(self: bool, other: bool) -> bool: @guppy.type(builtins, hugr_int_type, name="int") class Int: - @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) def __abs__(self: int) -> int: ... @@ -131,7 +142,9 @@ def __or__(self: int, other: int) -> int: def __pos__(self: int) -> int: ... - @guppy.custom(builtins, NotImplementedCompiler("ipow"), DefaultCallChecker()) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("ipow"), DefaultCallChecker() + ) # TODO def __pow__(self: int, other: int) -> int: ... @@ -143,19 +156,29 @@ def __radd__(self: int, other: int) -> int: def __rand__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op( + builtins, + int_op("idivmod_s", num_params=2), + ReversingChecker(DefaultCallChecker()), + ) def __rdivmod__(self: int, other: int) -> tuple[int, int]: ... - @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op( + builtins, int_op("idiv_s", num_params=2), ReversingChecker(DefaultCallChecker()) + ) def __rfloordiv__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("ishl", num_params=2), ReversingChecker(DefaultCallChecker())) # TODO: RHS is unsigned + @guppy.hugr_op( + builtins, int_op("ishl", num_params=2), ReversingChecker(DefaultCallChecker()) + ) # TODO: RHS is unsigned def __rlshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op( + builtins, int_op("imod_s", num_params=2), ReversingChecker(DefaultCallChecker()) + ) def __rmod__(self: int, other: int) -> int: ... @@ -171,11 +194,15 @@ def __ror__(self: int, other: int) -> int: def __round__(self: int) -> int: ... - @guppy.custom(builtins, NotImplementedCompiler("ipow"), ReversingChecker(DefaultCallChecker())) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("ipow"), ReversingChecker(DefaultCallChecker()) + ) # TODO def __rpow__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("ishr", num_params=2), ReversingChecker(DefaultCallChecker())) # TODO: RHS is unsigned + @guppy.hugr_op( + builtins, int_op("ishr", num_params=2), ReversingChecker(DefaultCallChecker()) + ) # TODO: RHS is unsigned def __rrshift__(self: int, other: int) -> int: ... @@ -187,7 +214,9 @@ def __rshift__(self: int, other: int) -> int: def __rsub__(self: int, other: int) -> int: ... - @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker(DefaultCallChecker())) + @guppy.custom( + builtins, IntTruedivCompiler(), ReversingChecker(DefaultCallChecker()) + ) def __rtruediv__(self: int, other: int) -> float: ... @@ -214,7 +243,6 @@ def __xor__(self: int, other: int) -> int: @guppy.type(builtins, hugr_float_type, name="float") class Float: - @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) def __abs__(self: float) -> float: ... @@ -259,7 +287,9 @@ def __ge__(self: float, other: float) -> bool: def __gt__(self: float, other: float) -> bool: ... - @guppy.hugr_op(builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker()) + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) def __int__(self: float) -> int: ... @@ -303,7 +333,9 @@ def __radd__(self: float, other: float) -> float: def __rdivmod__(self: float, other: float) -> tuple[float, float]: ... - @guppy.custom(builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker())) + @guppy.custom( + builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker()) + ) def __rfloordiv__(self: float, other: float) -> float: ... @@ -315,11 +347,15 @@ def __rmod__(self: float, other: float) -> float: def __rmul__(self: float, other: float) -> float: ... - @guppy.custom(builtins, NotImplementedCompiler("fround"), ReversingChecker(CoercingChecker())) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("fround"), ReversingChecker(CoercingChecker()) + ) # TODO def __round__(self: float) -> float: ... - @guppy.custom(builtins, NotImplementedCompiler("fpow"), ReversingChecker(CoercingChecker())) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("fpow"), ReversingChecker(CoercingChecker()) + ) # TODO def __rpow__(self: float, other: float) -> float: ... @@ -339,7 +375,9 @@ def __sub__(self: float, other: float) -> float: def __truediv__(self: float, other: float) -> float: ... - @guppy.hugr_op(builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker()) + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) def __trunc__(self: float) -> float: ... @@ -349,7 +387,9 @@ def abs(x): ... -@guppy.custom(builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False) +@guppy.custom( + builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False +) def _bool(x): ... @@ -359,22 +399,30 @@ def callable(x): ... -@guppy.custom(builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False) +@guppy.custom( + builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False +) def divmod(x, y): ... -@guppy.custom(builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False) +@guppy.custom( + builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False +) def _float(x, y): ... -@guppy.custom(builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False) +@guppy.custom( + builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False +) def _int(x): ... -@guppy.custom(builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False) +@guppy.custom( + builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False +) def pow(x, y): ... @@ -382,4 +430,3 @@ def pow(x, y): @guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) def round(x): ... - diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index fd1b9bc3..3b40af88 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -16,7 +16,11 @@ def quantum_op(op_name: str) -> ops.OpType: return ops.CustomOp(extension="quantum.tket2", op_name=op_name, args=[]) -@guppy.type(quantum, tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), linear=True) +@guppy.type( + quantum, + tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), + linear=True, +) class Qubit: pass From de454bdc98b1f986dfc8f973a336a97de116abc5 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 16:41:26 +0000 Subject: [PATCH 14/77] Rename guppy_types to types --- guppy/__init__.py | 2 +- guppy/ast_util.py | 4 ++-- guppy/cfg/builder.py | 2 +- guppy/checker/cfg_checker.py | 2 +- guppy/checker/core.py | 2 +- guppy/checker/expr_checker.py | 2 +- guppy/checker/func_checker.py | 2 +- guppy/checker/stmt_checker.py | 2 +- guppy/compiler/cfg_compiler.py | 2 +- guppy/compiler/core.py | 2 +- guppy/compiler/expr_compiler.py | 2 +- guppy/compiler/func_compiler.py | 2 +- guppy/compiler/stmt_compiler.py | 2 +- guppy/custom.py | 2 +- guppy/declared.py | 2 +- guppy/decorator.py | 2 +- guppy/error.py | 2 +- guppy/hugr/hugr.py | 2 +- guppy/module.py | 2 +- guppy/nodes.py | 2 +- guppy/prelude/_internal.py | 2 +- guppy/prelude/builtins.py | 2 +- guppy/{guppy_types.py => types.py} | 0 tests/hugr/test_dummy_nodes.py | 2 +- tests/hugr/test_ports.py | 2 +- 25 files changed, 25 insertions(+), 25 deletions(-) rename guppy/{guppy_types.py => types.py} (100%) diff --git a/guppy/__init__.py b/guppy/__init__.py index ff392293..bbe52f53 100644 --- a/guppy/__init__.py +++ b/guppy/__init__.py @@ -1 +1 @@ -__all__ = ["guppy_types"] +__all__ = ["types.py"] diff --git a/guppy/ast_util.py b/guppy/ast_util.py index bfe6bae9..cbd78448 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -2,7 +2,7 @@ from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING if TYPE_CHECKING: - from guppy.guppy_types import GuppyType + from guppy.types import GuppyType AstNode = Union[ ast.AST, @@ -179,7 +179,7 @@ def with_type(ty: "GuppyType", node: A) -> A: def get_type_opt(node: AstNode) -> Optional["GuppyType"]: """Tries to retrieve a type annotation from an AST node.""" - from guppy.guppy_types import GuppyType + from guppy.types import GuppyType try: ty = getattr(node, "type") diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 56c61310..9428f700 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -7,7 +7,7 @@ from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError -from guppy.guppy_types import NoneType +from guppy.types import NoneType from guppy.nodes import NestedFunctionDef # In order to build expressions, need an endless stream of unique temporary variables diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index fdcfb54a..2085f8c1 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -11,7 +11,7 @@ from guppy.checker.expr_checker import ExprSynthesizer, to_bool from guppy.checker.stmt_checker import StmtChecker from guppy.error import GuppyError -from guppy.guppy_types import GuppyType +from guppy.types import GuppyType VarRow = Sequence[Variable] diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 77cde03e..d64ab3de 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -4,7 +4,7 @@ from typing import NamedTuple, Optional, Union from guppy.ast_util import AstNode -from guppy.guppy_types import ( +from guppy.types import ( GuppyType, FunctionType, TupleType, diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index a5297f57..17a7c591 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.guppy_types import GuppyType, TupleType, FunctionType, BoolType +from guppy.types import GuppyType, TupleType, FunctionType, BoolType from guppy.nodes import LocalName, GlobalName, LocalCall # Mapping from unary AST op to dunder method and display name diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index facd10b6..290c6add 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -8,7 +8,7 @@ from guppy.checker.cfg_checker import check_cfg, CheckedCFG from guppy.checker.expr_checker import synthesize_call, check_call from guppy.error import GuppyError -from guppy.guppy_types import FunctionType, type_from_ast, NoneType, GuppyType +from guppy.types import FunctionType, type_from_ast, NoneType, GuppyType from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 5030c649..8b9dc77f 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -6,7 +6,7 @@ from guppy.checker.core import Variable, Context from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.guppy_types import GuppyType, TupleType, type_from_ast, NoneType +from guppy.types import GuppyType, TupleType, type_from_ast, NoneType from guppy.nodes import NestedFunctionDef diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index 7446af8e..fa9277b3 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -12,7 +12,7 @@ ) from guppy.compiler.expr_compiler import ExprCompiler from guppy.compiler.stmt_compiler import StmtCompiler -from guppy.guppy_types import TupleType, SumType, type_to_row +from guppy.types import TupleType, SumType, type_to_row from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 2e616e5a..093ee0f6 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstNode from guppy.checker.core import Variable, CallableVariable -from guppy.guppy_types import FunctionType +from guppy.types import FunctionType from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index b5c4c65b..75017cbc 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstVisitor, get_type from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction from guppy.error import InternalGuppyError -from guppy.guppy_types import FunctionType, type_to_row, BoolType +from guppy.types import FunctionType, type_to_row, BoolType from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index ee032cc9..60c7f526 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -9,7 +9,7 @@ DFContainer, PortVariable, ) -from guppy.guppy_types import type_to_row, FunctionType +from guppy.types import type_to_row, FunctionType from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index 61b80dff..a8d9c5c8 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -12,7 +12,7 @@ ) from guppy.compiler.expr_compiler import ExprCompiler from guppy.error import InternalGuppyError -from guppy.guppy_types import TupleType +from guppy.types import TupleType from guppy.hugr.hugr import Hugr, OutPortV from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/custom.py b/guppy/custom.py index 0c03f2a8..bb548e4e 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -14,7 +14,7 @@ UnknownFunctionType, GuppyTypeError, ) -from guppy.guppy_types import GuppyType, FunctionType, type_to_row, TupleType +from guppy.types import GuppyType, FunctionType, type_to_row, TupleType from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode from guppy.nodes import GlobalCall diff --git a/guppy/declared.py b/guppy/declared.py index 257c74ad..854e3823 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -8,7 +8,7 @@ from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError -from guppy.guppy_types import type_to_row, GuppyType +from guppy.types import type_to_row, GuppyType from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV from guppy.nodes import GlobalCall diff --git a/guppy/decorator.py b/guppy/decorator.py index 027e032f..6220c8aa 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -12,7 +12,7 @@ DefaultCallCompiler, ) from guppy.error import GuppyError, pretty_errors -from guppy.guppy_types import GuppyType +from guppy.types import GuppyType from guppy.hugr import tys, ops from guppy.hugr.hugr import Hugr from guppy.module import GuppyModule, PyFunc, parse_py_func diff --git a/guppy/error.py b/guppy/error.py index 77a08cea..550429f0 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -6,7 +6,7 @@ from typing import Optional, Any, Sequence, Callable, TypeVar, cast from guppy.ast_util import AstNode, get_line_offset, get_file, get_source -from guppy.guppy_types import GuppyType, FunctionType +from guppy.types import GuppyType, FunctionType from guppy.hugr.hugr import OutPortV, Node diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index d9f916e6..69d2fa84 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -8,7 +8,7 @@ import guppy.hugr.ops as ops import guppy.hugr.raw as raw -from guppy.guppy_types import ( +from guppy.types import ( GuppyType, TupleType, FunctionType, diff --git a/guppy/module.py b/guppy/module.py index a88a5dc5..fb29bb6d 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -13,7 +13,7 @@ from guppy.custom import CustomFunction from guppy.declared import DeclaredFunction from guppy.error import GuppyError, pretty_errors -from guppy.guppy_types import GuppyType +from guppy.types import GuppyType from guppy.hugr.hugr import Hugr PyFunc = Callable[..., Any] diff --git a/guppy/nodes.py b/guppy/nodes.py index 56795f1d..919caf82 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -3,7 +3,7 @@ import ast from typing import TYPE_CHECKING, Any, Mapping -from guppy.guppy_types import FunctionType +from guppy.types import FunctionType if TYPE_CHECKING: from guppy.cfg.cfg import CFG diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index c7b257e6..39816296 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -13,7 +13,7 @@ CustomCallCompiler, ) from guppy.error import GuppyTypeError -from guppy.guppy_types import GuppyType, type_to_row, FunctionType, BoolType +from guppy.types import GuppyType, type_to_row, FunctionType, BoolType from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV from guppy.nodes import GlobalCall diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index d83d4633..e3c7c19c 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -4,7 +4,7 @@ from guppy.custom import NoopCompiler, DefaultCallChecker from guppy.decorator import guppy -from guppy.guppy_types import BoolType +from guppy.types import BoolType from guppy.hugr import tys from guppy.module import GuppyModule from guppy.prelude._internal import ( diff --git a/guppy/guppy_types.py b/guppy/types.py similarity index 100% rename from guppy/guppy_types.py rename to guppy/types.py diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index f42ba4b1..73b7ee2d 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -1,4 +1,4 @@ -from guppy.guppy_types import FunctionType, BoolType, TupleType +from guppy.types import FunctionType, BoolType, TupleType from guppy.hugr import ops from guppy.hugr.hugr import Hugr diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py index 0e1d74e1..dd9fd7a0 100644 --- a/tests/hugr/test_ports.py +++ b/tests/hugr/test_ports.py @@ -1,7 +1,7 @@ import pytest from guppy.error import UndefinedPort, InternalGuppyError -from guppy.guppy_types import BoolType +from guppy.types import BoolType def test_undefined_port(): From 4dba8a7caf0878f8e9885d128aeb66e344bf4136 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 21 Nov 2023 17:39:06 +0000 Subject: [PATCH 15/77] Make node optional for type building --- guppy/decorator.py | 4 +++- guppy/prelude/_internal.py | 4 +--- guppy/types.py | 17 +++++++++++------ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/guppy/decorator.py b/guppy/decorator.py index 6220c8aa..67948cec 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -98,7 +98,9 @@ class NewType(GuppyType): name = _name @staticmethod - def build(*args: GuppyType, node: AstNode) -> "GuppyType": + def build( + *args: GuppyType, node: Optional[AstNode] = None + ) -> "GuppyType": # At the moment, custom types don't support type arguments. if len(args) > 0: raise GuppyError( diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 39816296..7a06ec5d 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -101,9 +101,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: call = with_loc( self.node, GlobalCall(func=Int.__float__, args=[args[i]]) ) - args[i] = with_type( - self.ctx.globals.types["float"].build(node=self.node), call - ) + args[i] = with_type(self.ctx.globals.types["float"].build(), call) return super().synthesize(args) diff --git a/guppy/types.py b/guppy/types.py index 4cfa330e..ca5c010a 100644 --- a/guppy/types.py +++ b/guppy/types.py @@ -20,7 +20,7 @@ class GuppyType(ABC): @staticmethod @abstractmethod - def build(*args: "GuppyType", node: AstNode) -> "GuppyType": + def build(*args: "GuppyType", node: Optional[AstNode] = None) -> "GuppyType": pass @property @@ -53,7 +53,7 @@ def __str__(self) -> str: return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" @staticmethod - def build(*args: GuppyType, node: AstNode) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: # Function types cannot be constructed using `build`. The type parsing code # has a special case for function types. raise NotImplementedError() @@ -71,7 +71,12 @@ class TupleType(GuppyType): name: str = "tuple" @staticmethod - def build(*args: GuppyType, node: AstNode) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + from guppy.error import GuppyError + + # TODO: Parse empty tuples via `tuple[()]` + if len(args) == 0: + raise GuppyError("Tuple type requires generic type arguments", node) return TupleType(list(args)) def __str__(self) -> str: @@ -91,7 +96,7 @@ class SumType(GuppyType): element_types: Sequence[GuppyType] @staticmethod - def build(*args: GuppyType, node: AstNode) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: # Sum types cannot be parsed and constructed using `build` since they cannot be # written by the user raise NotImplementedError() @@ -118,7 +123,7 @@ class NoneType(GuppyType): linear: bool = False @staticmethod - def build(*args: GuppyType, node: AstNode) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: if len(args) > 0: from guppy.error import GuppyError @@ -144,7 +149,7 @@ def __init__(self) -> None: super().__init__([TupleType([]), TupleType([])]) @staticmethod - def build(*args: GuppyType, node: AstNode) -> GuppyType: + def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: if len(args) > 0: from guppy.error import GuppyError From 3b0d0f4ee1fcb2be93822ed3e3325ff300ed46ee Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 09:38:01 +0000 Subject: [PATCH 16/77] Get rid of NotImplementedCompiler --- guppy/prelude/_internal.py | 17 +---------------- guppy/prelude/builtins.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 7a06ec5d..4d0ce447 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -12,7 +12,7 @@ CustomFunction, CustomCallCompiler, ) -from guppy.error import GuppyTypeError +from guppy.error import GuppyTypeError, GuppyError from guppy.types import GuppyType, type_to_row, FunctionType, BoolType from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV @@ -130,21 +130,6 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: return expr, ty -class NotImplementedCompiler(CustomCallCompiler): - """Compiler for functions that are not yet implemented.""" - - name: str - - def __init__(self, name: str) -> None: - self.name = name - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - node = self.graph.add_node(ops.DummyOp(name=self.name), inputs=args) - return_ty = get_type(self.node) - assert return_ty is not None - return [node.add_out_port(ty) for ty in type_to_row(return_ty)] - - class DunderChecker(CustomCallChecker): """Call checker for builtin functions that call out to dunder instance methods""" diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index e3c7c19c..793e32d0 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -5,7 +5,7 @@ from guppy.custom import NoopCompiler, DefaultCallChecker from guppy.decorator import guppy from guppy.types import BoolType -from guppy.hugr import tys +from guppy.hugr import tys, ops from guppy.module import GuppyModule from guppy.prelude._internal import ( logic_op, @@ -20,7 +20,6 @@ FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, - NotImplementedCompiler, DunderChecker, CallableChecker, ) @@ -142,9 +141,7 @@ def __or__(self: int, other: int) -> int: def __pos__(self: int) -> int: ... - @guppy.custom( - builtins, NotImplementedCompiler("ipow"), DefaultCallChecker() - ) # TODO + @guppy.hugr_op(builtins, ops.DummyOp(name="ipow")) # TODO def __pow__(self: int, other: int) -> int: ... @@ -194,8 +191,8 @@ def __ror__(self: int, other: int) -> int: def __round__(self: int) -> int: ... - @guppy.custom( - builtins, NotImplementedCompiler("ipow"), ReversingChecker(DefaultCallChecker()) + @guppy.hugr_op( + builtins, ops.DummyOp(name="ipow"), ReversingChecker(DefaultCallChecker()) ) # TODO def __rpow__(self: int, other: int) -> int: ... @@ -321,7 +318,7 @@ def __neg__(self: float, other: float) -> float: def __pos__(self: float) -> float: ... - @guppy.custom(builtins, NotImplementedCompiler("fpow"), CoercingChecker()) # TODO + @guppy.hugr_op(builtins, ops.DummyOp(name="fpow")) # TODO def __pow__(self: float, other: float) -> float: ... @@ -347,14 +344,12 @@ def __rmod__(self: float, other: float) -> float: def __rmul__(self: float, other: float) -> float: ... - @guppy.custom( - builtins, NotImplementedCompiler("fround"), ReversingChecker(CoercingChecker()) - ) # TODO + @guppy.hugr_op(builtins, ops.DummyOp(name="fround")) # TODO def __round__(self: float) -> float: ... - @guppy.custom( - builtins, NotImplementedCompiler("fpow"), ReversingChecker(CoercingChecker()) + @guppy.hugr_op( + builtins, ops.DummyOp(name="fpow"), ReversingChecker(DefaultCallChecker()) ) # TODO def __rpow__(self: float, other: float) -> float: ... From b7ccf59cdfeff73715f0d3200a1ca85c08a12d05 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 09:41:26 +0000 Subject: [PATCH 17/77] Make base_checker in ReversingChecker optional --- guppy/prelude/_internal.py | 6 +++--- guppy/prelude/builtins.py | 32 +++++++++++++------------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 4d0ce447..b5323fc2 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -13,7 +13,7 @@ CustomCallCompiler, ) from guppy.error import GuppyTypeError, GuppyError -from guppy.types import GuppyType, type_to_row, FunctionType, BoolType +from guppy.types import GuppyType, FunctionType, BoolType from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV from guppy.nodes import GlobalCall @@ -110,8 +110,8 @@ class ReversingChecker(CustomCallChecker): base_checker: CustomCallChecker - def __init__(self, base_checker: CustomCallChecker): - self.base_checker = base_checker + def __init__(self, base_checker: Optional[CustomCallChecker] = None): + self.base_checker = base_checker or DefaultCallChecker() def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: super()._setup(ctx, node, func) diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index 793e32d0..95430f70 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -145,45 +145,39 @@ def __pos__(self: int) -> int: def __pow__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker()) def __radd__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker()) def __rand__(self: int, other: int) -> int: ... - @guppy.hugr_op( - builtins, - int_op("idivmod_s", num_params=2), - ReversingChecker(DefaultCallChecker()), - ) + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker()) def __rdivmod__(self: int, other: int) -> tuple[int, int]: ... @guppy.hugr_op( - builtins, int_op("idiv_s", num_params=2), ReversingChecker(DefaultCallChecker()) + builtins, int_op("idiv_s", num_params=2), ReversingChecker() ) def __rfloordiv__(self: int, other: int) -> int: ... @guppy.hugr_op( - builtins, int_op("ishl", num_params=2), ReversingChecker(DefaultCallChecker()) + builtins, int_op("ishl", num_params=2), ReversingChecker() ) # TODO: RHS is unsigned def __rlshift__(self: int, other: int) -> int: ... - @guppy.hugr_op( - builtins, int_op("imod_s", num_params=2), ReversingChecker(DefaultCallChecker()) - ) + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker()) def __rmod__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker()) def __rmul__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("ior"), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op(builtins, int_op("ior"), ReversingChecker()) def __ror__(self: int, other: int) -> int: ... @@ -192,13 +186,13 @@ def __round__(self: int) -> int: ... @guppy.hugr_op( - builtins, ops.DummyOp(name="ipow"), ReversingChecker(DefaultCallChecker()) + builtins, ops.DummyOp(name="ipow"), ReversingChecker() ) # TODO def __rpow__(self: int, other: int) -> int: ... @guppy.hugr_op( - builtins, int_op("ishr", num_params=2), ReversingChecker(DefaultCallChecker()) + builtins, int_op("ishr", num_params=2), ReversingChecker() ) # TODO: RHS is unsigned def __rrshift__(self: int, other: int) -> int: ... @@ -207,17 +201,17 @@ def __rrshift__(self: int, other: int) -> int: def __rshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker()) def __rsub__(self: int, other: int) -> int: ... @guppy.custom( - builtins, IntTruedivCompiler(), ReversingChecker(DefaultCallChecker()) + builtins, IntTruedivCompiler(), ReversingChecker() ) def __rtruediv__(self: int, other: int) -> float: ... - @guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker()) def __rxor__(self: int, other: int) -> int: ... From c4bcbea7e901e907ebb3d9a303766fc2113b6740 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 10:23:03 +0000 Subject: [PATCH 18/77] Add unsupported builtins --- guppy/prelude/_internal.py | 17 ++ guppy/prelude/builtins.py | 328 ++++++++++++++++++++++++++++++++++++- 2 files changed, 336 insertions(+), 9 deletions(-) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index b5323fc2..a3693de6 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -130,6 +130,23 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: return expr, ty +class UnsupportedChecker(CustomCallChecker): + """Call checker for Python builtin functions that are not available in Guppy. + + Gives the uses a nicer error message when they try to use an unsupported feature. + """ + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + raise GuppyError( + f"Builtin method `{self.func.name}` is not supported by Guppy", self.node + ) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + raise GuppyError( + f"Builtin method `{self.func.name}` is not supported by Guppy", self.node + ) + + class DunderChecker(CustomCallChecker): """Call checker for builtin functions that call out to dunder instance methods""" diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index 95430f70..a47ca48c 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -22,6 +22,7 @@ FloatModCompiler, DunderChecker, CallableChecker, + UnsupportedChecker, ) @@ -157,9 +158,7 @@ def __rand__(self: int, other: int) -> int: def __rdivmod__(self: int, other: int) -> tuple[int, int]: ... - @guppy.hugr_op( - builtins, int_op("idiv_s", num_params=2), ReversingChecker() - ) + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker()) def __rfloordiv__(self: int, other: int) -> int: ... @@ -185,9 +184,7 @@ def __ror__(self: int, other: int) -> int: def __round__(self: int) -> int: ... - @guppy.hugr_op( - builtins, ops.DummyOp(name="ipow"), ReversingChecker() - ) # TODO + @guppy.hugr_op(builtins, ops.DummyOp(name="ipow"), ReversingChecker()) # TODO def __rpow__(self: int, other: int) -> int: ... @@ -205,9 +202,7 @@ def __rshift__(self: int, other: int) -> int: def __rsub__(self: int, other: int) -> int: ... - @guppy.custom( - builtins, IntTruedivCompiler(), ReversingChecker() - ) + @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker()) def __rtruediv__(self: int, other: int) -> float: ... @@ -419,3 +414,318 @@ def pow(x, y): @guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) def round(x): ... + + +# Python builtins that are not supported yet: + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def aiter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def all(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def anext(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def any(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bin(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def breakpoint(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bytearray(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bytes(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def chr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def classmethod(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def compile(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def complex(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def delattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def dict(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def dir(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def enumerate(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def eval(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def exec(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def filter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def format(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def forozenset(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def getattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def globals(): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hasattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hash(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def help(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hex(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def id(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def input(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def isinstance(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def issubclass(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def iter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def len(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def list(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def locals(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def map(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def max(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def memoryview(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def min(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def next(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def object(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def oct(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def open(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def ord(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def print(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def property(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def range(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def repr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def reversed(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def set(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def setattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def slice(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def sorted(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def staticmethod(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def str(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def sum(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def super(x): + ... + + +@guppy.custom( + builtins, name="tuple", checker=UnsupportedChecker(), higher_order_value=False +) +def _tuple(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def type(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def vars(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def zip(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def __import__(x): + ... From ab599bcef6b3129ab0eeca38487d30cba0f384d1 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 10:25:07 +0000 Subject: [PATCH 19/77] Rename types to gtypes --- guppy/ast_util.py | 4 ++-- guppy/cfg/builder.py | 2 +- guppy/checker/cfg_checker.py | 2 +- guppy/checker/core.py | 2 +- guppy/checker/expr_checker.py | 2 +- guppy/checker/func_checker.py | 2 +- guppy/checker/stmt_checker.py | 2 +- guppy/compiler/cfg_compiler.py | 2 +- guppy/compiler/core.py | 2 +- guppy/compiler/expr_compiler.py | 2 +- guppy/compiler/func_compiler.py | 2 +- guppy/compiler/stmt_compiler.py | 2 +- guppy/custom.py | 2 +- guppy/declared.py | 2 +- guppy/decorator.py | 2 +- guppy/error.py | 2 +- guppy/{types.py => gtypes.py} | 0 guppy/hugr/hugr.py | 2 +- guppy/module.py | 2 +- guppy/nodes.py | 2 +- guppy/prelude/_internal.py | 2 +- guppy/prelude/builtins.py | 2 +- tests/hugr/test_dummy_nodes.py | 2 +- tests/hugr/test_ports.py | 2 +- 24 files changed, 24 insertions(+), 24 deletions(-) rename guppy/{types.py => gtypes.py} (100%) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index cbd78448..8c4be4f7 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -2,7 +2,7 @@ from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING if TYPE_CHECKING: - from guppy.types import GuppyType + from guppy.gtypes import GuppyType AstNode = Union[ ast.AST, @@ -179,7 +179,7 @@ def with_type(ty: "GuppyType", node: A) -> A: def get_type_opt(node: AstNode) -> Optional["GuppyType"]: """Tries to retrieve a type annotation from an AST node.""" - from guppy.types import GuppyType + from guppy.gtypes import GuppyType try: ty = getattr(node, "type") diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 9428f700..7285a613 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -7,7 +7,7 @@ from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError -from guppy.types import NoneType +from guppy.gtypes import NoneType from guppy.nodes import NestedFunctionDef # In order to build expressions, need an endless stream of unique temporary variables diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 2085f8c1..1fbb0521 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -11,7 +11,7 @@ from guppy.checker.expr_checker import ExprSynthesizer, to_bool from guppy.checker.stmt_checker import StmtChecker from guppy.error import GuppyError -from guppy.types import GuppyType +from guppy.gtypes import GuppyType VarRow = Sequence[Variable] diff --git a/guppy/checker/core.py b/guppy/checker/core.py index d64ab3de..9bab6375 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -4,7 +4,7 @@ from typing import NamedTuple, Optional, Union from guppy.ast_util import AstNode -from guppy.types import ( +from guppy.gtypes import ( GuppyType, FunctionType, TupleType, diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 17a7c591..747d32c0 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.types import GuppyType, TupleType, FunctionType, BoolType +from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType from guppy.nodes import LocalName, GlobalName, LocalCall # Mapping from unary AST op to dunder method and display name diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 290c6add..87b22b8f 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -8,7 +8,7 @@ from guppy.checker.cfg_checker import check_cfg, CheckedCFG from guppy.checker.expr_checker import synthesize_call, check_call from guppy.error import GuppyError -from guppy.types import FunctionType, type_from_ast, NoneType, GuppyType +from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 8b9dc77f..5b12dc58 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -6,7 +6,7 @@ from guppy.checker.core import Variable, Context from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.types import GuppyType, TupleType, type_from_ast, NoneType +from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType from guppy.nodes import NestedFunctionDef diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index fa9277b3..d18c21b9 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -12,7 +12,7 @@ ) from guppy.compiler.expr_compiler import ExprCompiler from guppy.compiler.stmt_compiler import StmtCompiler -from guppy.types import TupleType, SumType, type_to_row +from guppy.gtypes import TupleType, SumType, type_to_row from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 093ee0f6..2b6a4fb2 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstNode from guppy.checker.core import Variable, CallableVariable -from guppy.types import FunctionType +from guppy.gtypes import FunctionType from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 75017cbc..5fdc378b 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstVisitor, get_type from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction from guppy.error import InternalGuppyError -from guppy.types import FunctionType, type_to_row, BoolType +from guppy.gtypes import FunctionType, type_to_row, BoolType from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index 60c7f526..ab066839 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -9,7 +9,7 @@ DFContainer, PortVariable, ) -from guppy.types import type_to_row, FunctionType +from guppy.gtypes import type_to_row, FunctionType from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index a8d9c5c8..db656b6f 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -12,7 +12,7 @@ ) from guppy.compiler.expr_compiler import ExprCompiler from guppy.error import InternalGuppyError -from guppy.types import TupleType +from guppy.gtypes import TupleType from guppy.hugr.hugr import Hugr, OutPortV from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/custom.py b/guppy/custom.py index bb548e4e..acfd7341 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -14,7 +14,7 @@ UnknownFunctionType, GuppyTypeError, ) -from guppy.types import GuppyType, FunctionType, type_to_row, TupleType +from guppy.gtypes import GuppyType, FunctionType, type_to_row, TupleType from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode from guppy.nodes import GlobalCall diff --git a/guppy/declared.py b/guppy/declared.py index 854e3823..91b5283e 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -8,7 +8,7 @@ from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError -from guppy.types import type_to_row, GuppyType +from guppy.gtypes import type_to_row, GuppyType from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV from guppy.nodes import GlobalCall diff --git a/guppy/decorator.py b/guppy/decorator.py index 67948cec..493ddcb2 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -12,7 +12,7 @@ DefaultCallCompiler, ) from guppy.error import GuppyError, pretty_errors -from guppy.types import GuppyType +from guppy.gtypes import GuppyType from guppy.hugr import tys, ops from guppy.hugr.hugr import Hugr from guppy.module import GuppyModule, PyFunc, parse_py_func diff --git a/guppy/error.py b/guppy/error.py index 550429f0..b0ace141 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -6,7 +6,7 @@ from typing import Optional, Any, Sequence, Callable, TypeVar, cast from guppy.ast_util import AstNode, get_line_offset, get_file, get_source -from guppy.types import GuppyType, FunctionType +from guppy.gtypes import GuppyType, FunctionType from guppy.hugr.hugr import OutPortV, Node diff --git a/guppy/types.py b/guppy/gtypes.py similarity index 100% rename from guppy/types.py rename to guppy/gtypes.py diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 69d2fa84..511fe702 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -8,7 +8,7 @@ import guppy.hugr.ops as ops import guppy.hugr.raw as raw -from guppy.types import ( +from guppy.gtypes import ( GuppyType, TupleType, FunctionType, diff --git a/guppy/module.py b/guppy/module.py index fb29bb6d..e07a15ac 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -13,7 +13,7 @@ from guppy.custom import CustomFunction from guppy.declared import DeclaredFunction from guppy.error import GuppyError, pretty_errors -from guppy.types import GuppyType +from guppy.gtypes import GuppyType from guppy.hugr.hugr import Hugr PyFunc = Callable[..., Any] diff --git a/guppy/nodes.py b/guppy/nodes.py index 919caf82..dfd02349 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -3,7 +3,7 @@ import ast from typing import TYPE_CHECKING, Any, Mapping -from guppy.types import FunctionType +from guppy.gtypes import FunctionType if TYPE_CHECKING: from guppy.cfg.cfg import CFG diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index a3693de6..3fbb406f 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -13,7 +13,7 @@ CustomCallCompiler, ) from guppy.error import GuppyTypeError, GuppyError -from guppy.types import GuppyType, FunctionType, BoolType +from guppy.gtypes import GuppyType, FunctionType, BoolType from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV from guppy.nodes import GlobalCall diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index a47ca48c..25dfc741 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -4,7 +4,7 @@ from guppy.custom import NoopCompiler, DefaultCallChecker from guppy.decorator import guppy -from guppy.types import BoolType +from guppy.gtypes import BoolType from guppy.hugr import tys, ops from guppy.module import GuppyModule from guppy.prelude._internal import ( diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index 73b7ee2d..e1efaab9 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -1,4 +1,4 @@ -from guppy.types import FunctionType, BoolType, TupleType +from guppy.gtypes import FunctionType, BoolType, TupleType from guppy.hugr import ops from guppy.hugr.hugr import Hugr diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py index dd9fd7a0..77e07ef4 100644 --- a/tests/hugr/test_ports.py +++ b/tests/hugr/test_ports.py @@ -1,7 +1,7 @@ import pytest from guppy.error import UndefinedPort, InternalGuppyError -from guppy.types import BoolType +from guppy.gtypes import BoolType def test_undefined_port(): From f7be21d3e1dc3ea9cfe5043e2e72e99d2f581a60 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 10:57:38 +0000 Subject: [PATCH 20/77] Split up into multiple PRs --- guppy/checker/cfg_checker.py | 157 +------ guppy/checker/expr_checker.py | 341 --------------- guppy/checker/func_checker.py | 139 +----- guppy/checker/stmt_checker.py | 126 ------ guppy/compiler/cfg_compiler.py | 164 +------- guppy/compiler/expr_compiler.py | 116 ----- guppy/compiler/func_compiler.py | 72 +--- guppy/compiler/stmt_compiler.py | 96 ----- guppy/custom.py | 32 +- guppy/declared.py | 9 +- guppy/prelude/_internal.py | 283 ------------- guppy/prelude/builtins.py | 725 -------------------------------- guppy/prelude/quantum.py | 35 +- 13 files changed, 30 insertions(+), 2265 deletions(-) delete mode 100644 guppy/checker/expr_checker.py delete mode 100644 guppy/checker/stmt_checker.py delete mode 100644 guppy/compiler/expr_compiler.py delete mode 100644 guppy/compiler/stmt_compiler.py delete mode 100644 guppy/prelude/_internal.py diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 1fbb0521..9cd0f41e 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -1,16 +1,11 @@ -import collections from dataclasses import dataclass from typing import Sequence -from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Globals, Context +from guppy.checker.core import Globals from guppy.checker.core import Variable -from guppy.checker.expr_checker import ExprSynthesizer, to_bool -from guppy.checker.stmt_checker import StmtChecker -from guppy.error import GuppyError from guppy.gtypes import GuppyType @@ -57,53 +52,7 @@ def check_cfg( Annotates the basic blocks with input and output type signatures and removes unreachable blocks. """ - # First, we need to run program analysis - ass_before = set(v.name for v in inputs) - cfg.analyze(ass_before, ass_before) - - # We start by compiling the entry BB - checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) - checked_cfg.entry_bb = check_bb( - cfg.entry_bb, checked_cfg, inputs, return_ty, globals - ) - compiled = {cfg.entry_bb: checked_cfg.entry_bb} - - # Visit all control-flow edges in BFS order. We can't just do a normal loop over - # all BBs since the input types for a BB are computed by checking a predecessor. - # We do BFS instead of DFS to get a better error ordering. - queue = collections.deque( - (checked_cfg.entry_bb, i, succ) - for i, succ in enumerate(cfg.entry_bb.successors) - ) - while len(queue) > 0: - pred, num_output, bb = queue.popleft() - input_row = [ - Variable(v.name, v.ty, v.defined_at, None) - for v in pred.sig.output_rows[num_output] - ] - - if bb in compiled: - # If the BB was already compiled, we just have to check that the signatures - # match. - check_rows_match(input_row, compiled[bb].sig.input_row, bb) - else: - # Otherwise, check the BB and enqueue its successors - checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) - queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)] - compiled[bb] = checked_bb - - # Link up BBs in the checked CFG - compiled[bb].predecessors.append(pred) - pred.successors[num_output] = compiled[bb] - - checked_cfg.bbs = list(compiled.values()) - checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable - checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} - checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} - checked_cfg.maybe_ass_before = { - compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs - } - return checked_cfg + raise NotImplementedError def check_bb( @@ -113,104 +62,4 @@ def check_bb( return_ty: GuppyType, globals: Globals, ) -> CheckedBB: - cfg = bb.cfg - - # For the entry BB we have to separately check that all used variables are - # defined. For all other BBs, this will be checked when compiling a predecessor. - if len(bb.predecessors) == 0: - for x, use in bb.vars.used.items(): - if x not in cfg.ass_before[bb] and x not in globals.values: - raise GuppyError(f"Variable `{x}` is not defined", use) - - # Check the basic block - ctx = Context(globals, {v.name: v for v in inputs}) - checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) - - # If we branch, we also have to check the branch predicate - if len(bb.successors) > 1: - assert bb.branch_pred is not None - bb.branch_pred, ty = ExprSynthesizer(ctx).synthesize(bb.branch_pred) - bb.branch_pred, _ = to_bool(bb.branch_pred, ty, ctx) - - for succ in bb.successors: - for x, use_bb in cfg.live_before[succ].items(): - # Check that the variables requested by the successor are defined - if x not in ctx.locals and x not in ctx.globals.values: - # If the variable is defined on *some* paths, we can give a more - # informative error message - if x in cfg.maybe_ass_before[use_bb]: - # TODO: This should be "Variable x is not defined when coming - # from {bb}". But for this we need a way to associate BBs with - # source locations. - raise GuppyError( - f"Variable `{x}` is not defined on all control-flow paths.", - use_bb.vars.used[x], - ) - raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x]) - - # We have to check that used linear variables are not being outputted - if x in ctx.locals: - var = ctx.locals[x] - if var.ty.linear and var.used: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - cfg.live_before[succ][x].vars.used[x], - [var.used], - ) - - # On the other hand, unused linear variables *must* be outputted - for x, var in ctx.locals.items(): - if var.ty.linear and not var.used and x not in cfg.live_before[succ]: - # TODO: This should be "Variable x with linear type ty is not - # used in {bb}". But for this we need a way to associate BBs with - # source locations. - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is " - "not used on all control-flow paths", - var.defined_at, - ) - - # Finally, we need to compute the signature of the basic block - outputs = [ - [ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals] - for succ in bb.successors - ] - - # Also prepare the successor list so we can fill it in later - checked_bb = CheckedBB( - bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) - ) - checked_bb.successors = [None] * len(bb.successors) # type: ignore - checked_bb.branch_pred = bb.branch_pred - return checked_bb - - -def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: - """Checks that the types of two rows match up. - - Otherwise, an error is thrown, alerting the user that a variable has different - types on different control-flow paths. - """ - map1, map2 = {v.name: v for v in row1}, {v.name: v for v in row2} - assert map1.keys() == map2.keys() - for x in map1: - v1, v2 = map1[x], map2[x] - if v1.ty != v2.ty: - # In the error message, we want to mention the variable that was first - # defined at the start. - if ( - v1.defined_at - and v2.defined_at - and line_col(v2.defined_at) < line_col(v1.defined_at) - ): - v1, v2 = v2, v1 - # We shouldn't mention temporary variables (starting with `%`) - # in error messages: - ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" - raise GuppyError( - f"{ident} can refer to different types: " - f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", - bb.cfg.live_before[bb][v1.name].vars.used[v1.name], - [v1.defined_at, v2.defined_at], - ) + raise NotImplementedError diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py deleted file mode 100644 index 747d32c0..00000000 --- a/guppy/checker/expr_checker.py +++ /dev/null @@ -1,341 +0,0 @@ -import ast -from typing import Optional, Union, NoReturn, Any - -from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt -from guppy.checker.core import Context, CallableVariable, Globals -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType -from guppy.nodes import LocalName, GlobalName, LocalCall - -# Mapping from unary AST op to dunder method and display name -unary_table: dict[type[ast.unaryop], tuple[str, str]] = { - ast.UAdd: ("__pos__", "+"), - ast.USub: ("__neg__", "-"), - ast.Invert: ("__invert__", "~"), -} - -# Mapping from binary AST op to left dunder method, right dunder method and display name -AstOp = Union[ast.operator, ast.cmpop] -binary_table: dict[type[AstOp], tuple[str, str, str]] = { - ast.Add: ("__add__", "__radd__", "+"), - ast.Sub: ("__sub__", "__rsub__", "-"), - ast.Mult: ("__mul__", "__rmul__", "*"), - ast.Div: ("__truediv__", "__rtruediv__", "/"), - ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), - ast.Mod: ("__mod__", "__rmod__", "%"), - ast.Pow: ("__pow__", "__rpow__", "**"), - ast.LShift: ("__lshift__", "__rlshift__", "<<"), - ast.RShift: ("__rshift__", "__rrshift__", ">>"), - ast.BitOr: ("__or__", "__ror__", "|"), - ast.BitXor: ("__xor__", "__rxor__", "^"), - ast.BitAnd: ("__and__", "__rand__", "&"), - ast.MatMult: ("__matmul__", "__rmatmul__", "@"), - ast.Eq: ("__eq__", "__eq__", "=="), - ast.NotEq: ("__neq__", "__neq__", "!="), - ast.Lt: ("__lt__", "__gt__", "<"), - ast.LtE: ("__le__", "__ge__", "<="), - ast.Gt: ("__gt__", "__lt__", ">"), - ast.GtE: ("__ge__", "__le__", ">="), -} - - -class ExprChecker(AstVisitor[ast.expr]): - """Checks an expression against a type and produces a new type-annotated AST""" - - ctx: Context - - # Name for the kind of term we are currently checking against (used in errors). - # For example, "argument", "return value", or in general "expression". - _kind: str - - def __init__(self, ctx: Context) -> None: - self.ctx = ctx - self._kind = "expression" - - def _fail( - self, - expected: GuppyType, - actual: Union[ast.expr, GuppyType], - loc: Optional[AstNode] = None, - ) -> NoReturn: - """Raises a type error indicating that the type doesn't match.""" - if not isinstance(actual, GuppyType): - loc = loc or actual - _, actual = self._synthesize(actual) - assert loc is not None - raise GuppyTypeError( - f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc - ) - - def check( - self, expr: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: - """Checks an expression against a type. - - Returns a new desugared expression with type annotations. - """ - old_kind = self._kind - self._kind = kind or self._kind - expr = self.visit(expr, ty) - self._kind = old_kind - return with_type(ty, expr) - - def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: - """Invokes the type synthesiser""" - return ExprSynthesizer(self.ctx).synthesize(node) - - def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: - if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): - return self._fail(ty, node) - for i, el in enumerate(node.elts): - node.elts[i] = self.check(el, ty.element_types[i]) - return node - - def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore - # Try to synthesize and then check if it matches the given type - node, synth = self._synthesize(node) - if synth != ty: - self._fail(ty, synth, node) - return node - - -class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): - ctx: Context - - def __init__(self, ctx: Context) -> None: - self.ctx = ctx - - def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: - """Tries to synthesise a type for the given expression. - - Also returns a new desugared expression with type annotations. - """ - if ty := get_type_opt(node): - return node, ty - node, ty = self.visit(node) - return with_type(ty, node), ty - - def _check( - self, expr: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: - """Checks an expression against a given type""" - return ExprChecker(self.ctx).check(expr, ty, kind) - - def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: - ty = python_value_to_guppy_type(node.value, node, self.ctx.globals) - if ty is None: - raise GuppyError("Unsupported constant", node) - return node, ty - - def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: - x = node.id - if x in self.ctx.locals: - var = self.ctx.locals[x] - if var.ty.linear and var.used is not None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - node, - [var.used], - ) - var.used = node - return with_loc(node, LocalName(id=x)), var.ty - elif x in self.ctx.globals.values: - # Cache value in the AST - value = self.ctx.globals.values[x] - return with_loc(node, GlobalName(id=x, value=value)), value.ty - raise InternalGuppyError( - f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " - f"been caught by program analysis!" - ) - - def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: - elems = [self.synthesize(elem) for elem in node.elts] - node.elts = [n for n, _ in elems] - return node, TupleType([ty for _, ty in elems]) - - def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: - # We need to synthesise the argument type, so we can look up dunder methods - node.operand, op_ty = self.synthesize(node.operand) - - # Special case for the `not` operation since it is not implemented via a dunder - # method or control-flow - if isinstance(node.op, ast.Not): - node.operand, bool_ty = to_bool(node.operand, op_ty, self.ctx) - return node, bool_ty - - # Check all other unary expressions by calling out to instance dunder methods - op, display_name = unary_table[node.op.__class__] - func = self.ctx.globals.get_instance_func(op_ty, op) - if func is None: - raise GuppyTypeError( - f"Unary operator `{display_name}` not defined for argument of type " - f" `{op_ty}`", - node.operand, - ) - return func.synthesize_call([node.operand], node, self.ctx) - - def _synthesize_binary( - self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr - ) -> tuple[ast.expr, GuppyType]: - """Helper method to compile binary operators by calling out to dunder methods. - - For example, first try calling `__add__` on the left operand. If that fails, try - `__radd__` on the right operand. - """ - if op.__class__ not in binary_table: - raise GuppyError("This binary operation is not supported by Guppy.", op) - lop, rop, display_name = binary_table[op.__class__] - left_expr, left_ty = self.synthesize(left_expr) - right_expr, right_ty = self.synthesize(right_expr) - - if func := self.ctx.globals.get_instance_func(left_ty, lop): - try: - return func.synthesize_call([left_expr, right_expr], node, self.ctx) - except GuppyError: - pass - - if func := self.ctx.globals.get_instance_func(right_ty, rop): - try: - return func.synthesize_call([right_expr, left_expr], node, self.ctx) - except GuppyError: - pass - - raise GuppyTypeError( - f"Binary operator `{display_name}` not defined for arguments of type " - f"`{left_ty}` and `{right_ty}`", - node, - ) - - def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: - return self._synthesize_binary(node.left, node.right, node.op, node) - - def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: - if len(node.comparators) != 1 or len(node.ops) != 1: - raise InternalGuppyError( - "BB contains chained comparison. Should have been removed during CFG " - "construction." - ) - left_expr, [op], [right_expr] = node.left, node.ops, node.comparators - return self._synthesize_binary(left_expr, right_expr, op, node) - - def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: - if len(node.keywords) > 0: - raise GuppyError( - "Argument passing by keyword is not supported", node.keywords[0] - ) - node.func, ty = self.synthesize(node.func) - - # First handle direct calls of user-defined functions and extension functions - if isinstance(node.func, GlobalName) and isinstance( - node.func.value, CallableVariable - ): - return node.func.value.synthesize_call(node.args, node, self.ctx) - - # Otherwise, it must be a function as a higher-order value - if isinstance(ty, FunctionType): - args, return_ty = synthesize_call(ty, node.args, node, self.ctx) - return with_loc(node, LocalCall(func=node.func, args=args)), return_ty - elif f := self.ctx.globals.get_instance_func(ty, "__call__"): - return f.synthesize_call(node.args, node, self.ctx) - else: - raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) - - def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: - raise InternalGuppyError( - "BB contains `NamedExpr`. Should have been removed during CFG" - f"construction: `{ast.unparse(node)}`" - ) - - def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, GuppyType]: - raise InternalGuppyError( - "BB contains `BoolOp`. Should have been removed during CFG construction: " - f"`{ast.unparse(node)}`" - ) - - def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: - raise InternalGuppyError( - "BB contains `IfExp`. Should have been removed during CFG construction: " - f"`{ast.unparse(node)}`" - ) - - -def check_num_args(exp: int, act: int, node: AstNode) -> None: - """Checks that the correct number of arguments have been passed to a function.""" - if act < exp: - raise GuppyTypeError( - f"Not enough arguments passed (expected {exp}, got {act})", node - ) - if exp < act: - if isinstance(node, ast.Call): - raise GuppyTypeError("Unexpected argument", node.args[exp]) - raise GuppyTypeError( - f"Too many arguments passed (expected {exp}, got {act})", node - ) - - -def synthesize_call( - func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context -) -> tuple[list[ast.expr], GuppyType]: - """Synthesizes the return type of a function call""" - check_num_args(len(func_ty.args), len(args), node) - for i, arg in enumerate(args): - args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") - return args, func_ty.returns - - -def check_call( - func_ty: FunctionType, - args: list[ast.expr], - ty: GuppyType, - node: AstNode, - ctx: Context, -) -> list[ast.expr]: - """Checks the return type of a function call against a given type""" - args, return_ty = synthesize_call(func_ty, args, node, ctx) - if return_ty != ty: - raise GuppyTypeError( - f"Expected expression of type `{ty}`, got `{return_ty}`", node - ) - return args - - -def to_bool( - node: ast.expr, node_ty: GuppyType, ctx: Context -) -> tuple[ast.expr, GuppyType]: - """Tries to turn a node into a bool""" - if isinstance(node_ty, BoolType): - return node, node_ty - - func = ctx.globals.get_instance_func(node_ty, "__bool__") - if func is None: - raise GuppyTypeError( - f"Expression of type `{node_ty}` cannot be interpreted as a `bool`", - node, - ) - - # We could check the return type against bool, but we can give a better error - # message if we synthesise and compare to bool by hand - call, return_ty = func.synthesize_call([node], node, ctx) - if not isinstance(return_ty, BoolType): - raise GuppyTypeError( - f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", - node, - ) - return call, return_ty - - -def python_value_to_guppy_type( - v: Any, node: ast.expr, globals: Globals -) -> Optional[GuppyType]: - """Turns a primitive Python value into a Guppy type. - - Returns `None` if the Python value cannot be represented in Guppy. - """ - if isinstance(v, bool): - return globals.types["bool"].build(node=node) - elif isinstance(v, int): - return globals.types["int"].build(node=node) - if isinstance(v, float): - return globals.types["float"].build(node=node) - return None diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 87b22b8f..cb2d3234 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -1,14 +1,11 @@ import ast from dataclasses import dataclass -from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc +from guppy.ast_util import AstNode from guppy.cfg.bb import BB -from guppy.cfg.builder import CFGBuilder -from guppy.checker.core import Variable, Globals, Context, CallableVariable -from guppy.checker.cfg_checker import check_cfg, CheckedCFG -from guppy.checker.expr_checker import synthesize_call, check_call -from guppy.error import GuppyError -from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType +from guppy.checker.core import Globals, Context, CallableVariable +from guppy.checker.cfg_checker import CheckedCFG +from guppy.gtypes import FunctionType, GuppyType from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef @@ -29,16 +26,12 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context ) -> GlobalCall: - # Use default implementation from the expression checker - args = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args) + raise NotImplementedError def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: - # Use default implementation from the expression checker - args, ty = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args), ty + raise NotImplementedError @dataclass @@ -50,131 +43,17 @@ class CheckedFunction(DefinedFunction): def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFunction: """Type checks a top-level function definition.""" - func_def = func.defined_at - args = func_def.args.args - returns_none = isinstance(func.ty.returns, NoneType) - assert func.ty.arg_names is not None - - cfg = CFGBuilder().build(func_def.body, returns_none, globals) - inputs = [ - Variable(x, ty, loc, None) - for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) - ] - cfg = check_cfg(cfg, inputs, func.ty.returns, globals) - return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) + raise NotImplementedError def check_nested_func_def( func_def: NestedFunctionDef, bb: BB, ctx: Context ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" - func_ty = check_signature(func_def, ctx.globals) - assert func_ty.arg_names is not None - - # We've already built the CFG for this function while building the CFG of the - # enclosing function - cfg = func_def.cfg - - # Find captured variables - parent_cfg = bb.cfg - def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() - maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] - cfg.analyze(def_ass_before, maybe_ass_before) - captured = { - x: ctx.locals[x] - for x in cfg.live_before[cfg.entry_bb] - if x not in func_ty.arg_names and x in ctx.locals - } - - # Captured variables may not be linear - for v in captured.values(): - if v.ty.linear: - x = v.name - using_bb = cfg.live_before[cfg.entry_bb][x] - raise GuppyError( - f"Variable `{x}` with linear type `{v.ty}` may not be used here " - f"because it was defined in an outer scope (at {{0}})", - using_bb.vars.used[x], - [v.defined_at], - ) - - # Captured variables may never be assigned to - for bb in cfg.bbs: - for v in captured.values(): - x = v.name - if x in bb.vars.assigned: - raise GuppyError( - f"Variable `{x}` defined in an outer scope (at {{0}}) may not " - f"be assigned to", - bb.vars.assigned[x], - [v.defined_at], - ) - - # Construct inputs for checking the body CFG - inputs = list(captured.values()) + [ - Variable(x, ty, func_def.args.args[i], None) - for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) - ] - globals = ctx.globals - - # Check if the body contains a recursive occurrence of the function name - if func_def.name in cfg.live_before[cfg.entry_bb]: - if len(captured) == 0: - # If there are no captured vars, we treat the function like a global name - func = DefinedFunction(func_def.name, func_ty, func_def, None) - globals = ctx.globals | Globals({func_def.name: func}, {}) - - else: - # Otherwise, we treat it like a local name - inputs.append(Variable(func_def.name, func_def.ty, func_def, None)) - - checked_cfg = check_cfg(cfg, inputs, func_ty.returns, globals) - checked_def = CheckedNestedFunctionDef( - checked_cfg, - func_ty, - captured, - name=func_def.name, - args=func_def.args, - body=func_def.body, - decorator_list=func_def.decorator_list, - returns=func_def.returns, - type_comment=func_def.type_comment, - ) - return with_loc(func_def, checked_def) + raise NotImplementedError def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: """Checks the signature of a function definition and returns the corresponding Guppy type.""" - if len(func_def.args.posonlyargs) != 0: - raise GuppyError( - "Positional-only parameters not supported", func_def.args.posonlyargs[0] - ) - if len(func_def.args.kwonlyargs) != 0: - raise GuppyError( - "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] - ) - if func_def.args.vararg is not None: - raise GuppyError("*args not supported", func_def.args.vararg) - if func_def.args.kwarg is not None: - raise GuppyError("**kwargs not supported", func_def.args.kwarg) - if func_def.returns is None: - # TODO: Error location is incorrect - if all(r.value is None for r in return_nodes_in_ast(func_def)): - raise GuppyError( - "Return type must be annotated. Try adding a `-> None` annotation.", - func_def, - ) - raise GuppyError("Return type must be annotated", func_def) - - arg_tys = [] - arg_names = [] - for i, arg in enumerate(func_def.args.args): - if arg.annotation is None: - raise GuppyError("Argument type must be annotated", arg) - ty = type_from_ast(arg.annotation, globals) - arg_tys.append(ty) - arg_names.append(arg.arg) - - ret_type = type_from_ast(func_def.returns, globals) - return FunctionType(arg_tys, ret_type, arg_names) + raise NotImplementedError diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py deleted file mode 100644 index 5b12dc58..00000000 --- a/guppy/checker/stmt_checker.py +++ /dev/null @@ -1,126 +0,0 @@ -import ast -from typing import Sequence - -from guppy.ast_util import with_loc, AstVisitor -from guppy.cfg.bb import BB, BBStatement -from guppy.checker.core import Variable, Context -from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType -from guppy.nodes import NestedFunctionDef - - -class StmtChecker(AstVisitor[BBStatement]): - ctx: Context - bb: BB - return_ty: GuppyType - - def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: - self.ctx = ctx - self.bb = bb - self.return_ty = return_ty - - def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: - return [self.visit(s) for s in stmts] - - def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: - return ExprSynthesizer(self.ctx).synthesize(node) - - def _check_expr( - self, node: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: - return ExprChecker(self.ctx).check(node, ty, kind) - - def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: - """Helper function to check assignments with patterns.""" - # Easiest case is if the LHS pattern is a single variable. - if isinstance(lhs, ast.Name): - # Check if we override an unused linear variable - x = lhs.id - if x in self.ctx.locals: - var = self.ctx.locals[x] - if var.ty.linear and var.used is None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is not used", - var.defined_at, - ) - self.ctx.locals[x] = Variable(x, ty, node, None) - # The only other thing we support right now are tuples - elif isinstance(lhs, ast.Tuple): - tys = ty.element_types if isinstance(ty, TupleType) else [ty] - n, m = len(lhs.elts), len(tys) - if n != m: - raise GuppyTypeError( - f"{'Too many' if n < m else 'Not enough'} values to unpack " - f"(expected {n}, got {m})", - node, - ) - for pat, el_ty in zip(lhs.elts, tys): - self._check_assign(pat, el_ty, node) - # TODO: Python also supports assignments like `[a, b] = [1, 2]` or - # `a, *b = ...`. The former would require some runtime checks but - # the latter should be easier to do (unpack and repack the rest). - else: - raise GuppyError("Assignment pattern not supported", lhs) - - def visit_Assign(self, node: ast.Assign) -> ast.stmt: - if len(node.targets) > 1: - # This is the case for assignments like `a = b = 1` - raise GuppyError("Multi assignment not supported", node) - - [target] = node.targets - node.value, ty = self._synth_expr(node.value) - self._check_assign(target, ty, node) - return node - - def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: - if node.value is None: - raise GuppyError( - "Variable declaration is not supported. Assignment is required", node - ) - ty = type_from_ast(node.annotation, self.ctx.globals) - node.value = self._check_expr(node.value, ty) - self._check_assign(node.target, ty, node) - return node - - def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: - bin_op = with_loc( - node, ast.BinOp(left=node.target, op=node.op, right=node.value) - ) - assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) - return self.visit_Assign(assign) - - def visit_Expr(self, node: ast.Expr) -> ast.stmt: - # An expression statement where the return value is discarded - node.value, _ = self._synth_expr(node.value) - return node - - def visit_Return(self, node: ast.Return) -> ast.stmt: - if node.value is not None: - node.value = self._check_expr(node.value, self.return_ty, "return value") - elif not isinstance(self.return_ty, NoneType): - raise GuppyTypeError( - f"Expected return value of type `{self.return_ty}`", None - ) - return node - - def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: - from guppy.checker.func_checker import check_nested_func_def - - func_def = check_nested_func_def(node, self.bb, self.ctx) - self.ctx.locals[func_def.name] = Variable( - func_def.name, func_def.ty, func_def, None - ) - return func_def - - def visit_If(self, node: ast.If) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_While(self, node: ast.While) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_Break(self, node: ast.Break) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") - - def visit_Continue(self, node: ast.Continue) -> None: - raise InternalGuppyError("Control-flow statement should not be present here.") diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index d18c21b9..cf4717eb 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -1,171 +1,17 @@ -import functools -from typing import Sequence - -from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature -from guppy.checker.core import Variable -from guppy.compiler.core import ( - CompiledGlobals, - is_return_var, - DFContainer, - return_var, - PortVariable, -) -from guppy.compiler.expr_compiler import ExprCompiler -from guppy.compiler.stmt_compiler import StmtCompiler -from guppy.gtypes import TupleType, SumType, type_to_row -from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV +from guppy.checker.cfg_checker import CheckedBB, CheckedCFG +from guppy.compiler.core import CompiledGlobals +from guppy.hugr.hugr import Hugr, Node, CFNode def compile_cfg( cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals ) -> None: """Compiles a CFG to Hugr.""" - insert_return_vars(cfg) - - blocks: dict[CheckedBB, CFNode] = {} - for bb in cfg.bbs: - blocks[bb] = compile_bb(bb, graph, parent, globals) - for bb in cfg.bbs: - for succ in bb.successors: - graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) + raise NotImplementedError def compile_bb( bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals ) -> CFNode: """Compiles a single basic block to Hugr.""" - inputs = sort_vars(bb.sig.input_row) - - # The exit BB is completely empty - if len(bb.successors) == 0: - assert len(bb.statements) == 0 - return graph.add_exit([v.ty for v in inputs], parent) - - # Otherwise, we use a regular `Block` node - block = graph.add_block(parent) - - # Add input node and compile the statements - inp = graph.add_input(output_tys=[v.ty for v in inputs], parent=block) - dfg = DFContainer( - block, - { - v.name: PortVariable(v.name, inp.out_port(i), v.defined_at, None) - for (i, v) in enumerate(inputs) - }, - ) - dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) - - # If we branch, we also have to compile the branch predicate - if len(bb.successors) > 1: - assert bb.branch_pred is not None - branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg) - else: - # Even if we don't branch, we still have to add a `Sum(())` predicates - unit = graph.add_make_tuple([], parent=block).out_port(0) - branch_port = graph.add_tag( - variants=[TupleType([])], tag=0, inp=unit, parent=block - ).out_port(0) - - # Finally, we have to add the block output. - outputs: Sequence[Variable] - if len(bb.successors) == 1: - # The easy case is if we don't branch: We just output all variables that are - # specified by the signature - [outputs] = bb.sig.output_rows - else: - # If we branch and the branches use the same variables, then we can use a - # regular output - first, *rest = bb.sig.output_rows - if all({v.name for v in first} == {v.name for v in r} for r in rest): - outputs = first - else: - # Otherwise, we have to output a Sum-type predicate: We put all non-linear - # variables into the branch predicate and all linear variables in the normal - # output (since they are shared between all successors). This is in line - # with the ordering on variables which puts linear variables at the end. - # The only exception are return vars which must be outputted in order. - branch_port = choose_vars_for_pred( - graph=graph, - pred=branch_port, - output_vars=[ - [v for v in row if not v.ty.linear or is_return_var(v.name)] - for row in bb.sig.output_rows - ], - dfg=dfg, - ) - outputs = [v for v in first if v.ty.linear and not is_return_var(v.name)] - - graph.add_output( - inputs=[branch_port] + [dfg[v.name].port for v in sort_vars(outputs)], - parent=block, - ) - return block - - -def insert_return_vars(cfg: CheckedCFG) -> None: - """Patches a CFG by annotating dummy return variables in the BB signatures. - - The statement compiler turns `return` statements into assignments of dummy variables - `%ret0`, `%ret1`, etc. We update the exit BB signature to make sure they are - correctly outputted. - """ - return_vars = [ - Variable(return_var(i), ty, None, None) - for i, ty in enumerate(type_to_row(cfg.output_ty)) - ] - # Before patching, the exit BB shouldn't take any inputs - assert len(cfg.exit_bb.sig.input_row) == 0 - cfg.exit_bb.sig = Signature(return_vars, cfg.exit_bb.sig.output_rows) - # Also patch the predecessors - for pred in cfg.exit_bb.predecessors: - # The exit BB will be the only successor - assert len(pred.sig.output_rows) == 1 and len(pred.sig.output_rows[0]) == 0 - pred.sig = Signature(pred.sig.input_row, [return_vars]) - - -def choose_vars_for_pred( - graph: Hugr, pred: OutPortV, output_vars: list[VarRow], dfg: DFContainer -) -> OutPortV: - """Selects an output based on a predicate. - - Given `pred: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, constructs - a predicate value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. - """ - assert isinstance(pred.ty, SumType) - assert len(pred.ty.element_types) == len(output_vars) - tuples = [ - graph.add_make_tuple( - inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], - parent=dfg.node, - ).out_port(0) - for vs in output_vars - ] - tys = [t.ty for t in tuples] - conditional = graph.add_conditional(cond_input=pred, inputs=tuples, parent=dfg.node) - for i, ty in enumerate(tys): - case = graph.add_case(conditional) - inp = graph.add_input(output_tys=tys, parent=case).out_port(i) - tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) - graph.add_output(inputs=[tag], parent=case) - return conditional.add_out_port(SumType(tys)) - - -def compare_var(x: Variable, y: Variable) -> int: - """Defines a `<` order on variables. - - We use this to determine in which order variables are outputted from basic blocks. - We need to output linear variables at the end, so we do a lexicographic ordering of - linearity and name. The only exception are return vars which must be outputted in - order. - """ - if is_return_var(x.name) and is_return_var(y.name): - return -1 if x.name < y.name else 1 - return -1 if (x.ty.linear, x.name) < (y.ty.linear, y.name) else 1 - - -def sort_vars(row: VarRow) -> list[Variable]: - """Sorts a row of variables. - - This determines the order in which they are outputted from a BB. - """ - return sorted(row, key=functools.cmp_to_key(compare_var)) + raise NotImplementedError diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py deleted file mode 100644 index 5fdc378b..00000000 --- a/guppy/compiler/expr_compiler.py +++ /dev/null @@ -1,116 +0,0 @@ -import ast -from typing import Any, Optional - -from guppy.ast_util import AstVisitor, get_type -from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction -from guppy.error import InternalGuppyError -from guppy.gtypes import FunctionType, type_to_row, BoolType -from guppy.hugr import ops, val -from guppy.hugr.hugr import OutPortV -from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall - - -class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): - """A compiler from Guppy expressions to Hugr.""" - - dfg: DFContainer - - def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: - """Compiles an expression and returns a single port holding the output value.""" - self.dfg = dfg - with self.graph.parent(dfg.node): - res = self.visit(expr) - return res - - def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: - """Compiles a row expression and returns a list of ports, one for each value in - the row. - - On Python-level, we treat tuples like rows on top-level. However, nested tuples - are treated like regular Guppy tuples. - """ - return [self.compile(e, dfg) for e in expr_to_row(expr)] - - def visit_Constant(self, node: ast.Constant) -> OutPortV: - if value := python_value_to_hugr(node.value): - const = self.graph.add_constant(value, get_type(node)).out_port(0) - return self.graph.add_load_constant(const).out_port(0) - raise InternalGuppyError("Unsupported constant expression in compiler") - - def visit_LocalName(self, node: LocalName) -> OutPortV: - return self.dfg[node.id].port - - def visit_GlobalName(self, node: GlobalName) -> OutPortV: - return self.globals[node.id].load(self.dfg, self.graph, self.globals, node) - - def visit_Name(self, node: ast.Name) -> OutPortV: - raise InternalGuppyError("Node should have been removed during type checking.") - - def visit_Tuple(self, node: ast.Tuple) -> OutPortV: - return self.graph.add_make_tuple( - inputs=[self.visit(e) for e in node.elts] - ).out_port(0) - - def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: - """Groups function return values into a tuple""" - if len(returns) != 1: - return self.graph.add_make_tuple(inputs=returns).out_port(0) - return returns[0] - - def visit_LocalCall(self, node: LocalCall) -> OutPortV: - func = self.visit(node.func) - assert isinstance(func.ty, FunctionType) - - args = [self.visit(arg) for arg in node.args] - call = self.graph.add_indirect_call(func, args) - rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.returns)))] - return self._pack_returns(rets) - - def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: - func = self.globals[node.func.name] - assert isinstance(func, CompiledFunction) - - args = [self.visit(arg) for arg in node.args] - rets = func.compile_call(args, self.dfg, self.graph, self.globals, node) - return self._pack_returns(rets) - - def visit_Call(self, node: ast.Call) -> OutPortV: - raise InternalGuppyError("Node should have been removed during type checking.") - - def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: - # The only case that is not desugared by the type checker is the `not` operation - # since it is not implemented via a dunder method - if isinstance(node.op, ast.Not): - arg = self.visit(node.operand) - return self.graph.add_node( - ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] - ).add_out_port(BoolType()) - - raise InternalGuppyError("Node should have been removed during type checking.") - - def visit_BinOp(self, node: ast.BinOp) -> OutPortV: - raise InternalGuppyError("Node should have been removed during type checking.") - - def visit_Compare(self, node: ast.Compare) -> OutPortV: - raise InternalGuppyError("Node should have been removed during type checking.") - - -def expr_to_row(expr: ast.expr) -> list[ast.expr]: - """Turns an expression into a row expressions by unpacking top-level tuples.""" - return expr.elts if isinstance(expr, ast.Tuple) else [expr] - - -def python_value_to_hugr(v: Any) -> Optional[val.Value]: - """Turns a Python value into a Hugr value. - - Returns None if the Python value cannot be represented in Guppy. - """ - from guppy.prelude._internal import int_value, bool_value, float_value - - if isinstance(v, bool): - return bool_value(v) - elif isinstance(v, int): - return int_value(v) - elif isinstance(v, float): - return float_value(v) - return None diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index ab066839..bdc102df 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -2,14 +2,12 @@ from guppy.ast_util import AstNode from guppy.checker.func_checker import CheckedFunction, DefinedFunction -from guppy.compiler.cfg_compiler import compile_cfg from guppy.compiler.core import ( CompiledFunction, CompiledGlobals, DFContainer, PortVariable, ) -from guppy.gtypes import type_to_row, FunctionType from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef @@ -21,7 +19,7 @@ class CompiledFunctionDef(DefinedFunction, CompiledFunction): def load( self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode ) -> OutPortV: - return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) + raise NotImplementedError def compile_call( self, @@ -31,8 +29,7 @@ def compile_call( globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - call = graph.add_call(self.node.out_port(0), args, dfg.node) - return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] + raise NotImplementedError def compile_global_func_def( @@ -42,19 +39,7 @@ def compile_global_func_def( globals: CompiledGlobals, ) -> CompiledFunctionDef: """Compiles a top-level function definition to Hugr.""" - def_input = graph.add_input(parent=def_node) - cfg_node = graph.add_cfg( - def_node, inputs=[def_input.add_out_port(ty) for ty in func.ty.args] - ) - compile_cfg(func.cfg, graph, cfg_node, globals) - - # Add output node for the cfg - graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], - parent=def_node, - ) - - return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) + raise NotImplementedError def compile_local_func_def( @@ -64,53 +49,4 @@ def compile_local_func_def( globals: CompiledGlobals, ) -> PortVariable: """Compiles a local (nested) function definition to Hugr.""" - assert func.ty.arg_names is not None - - # Pick an order for the captured variables - captured = list(func.captured.values()) - - # Prepend captured variables to the function arguments - closure_ty = FunctionType( - [v.ty for v in captured] + list(func.ty.args), - func.ty.returns, - [v.name for v in captured] + list(func.ty.arg_names), - ) - - def_node = graph.add_def(closure_ty, dfg.node, func.name) - def_input = graph.add_input(parent=def_node) - input_ports = [def_input.add_out_port(ty) for ty in closure_ty.args] - - # If we have captured variables and the body contains a recursive occurrence of - # the function itself, then we provide the partially applied function as a local - # variable - if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: - loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0) - partial = graph.add_partial( - loaded, [def_input.out_port(i) for i in range(len(captured))], def_node - ) - input_ports.append(partial.out_port(0)) - func.cfg.input_tys.append(func.ty) - else: - # Otherwise, we treat the function like a normal global variable - globals = globals | { - func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node) - } - - # Compile the CFG - cfg_node = graph.add_cfg(def_node, inputs=input_ports) - compile_cfg(func.cfg, graph, cfg_node, globals) - - # Add output node for the cfg - graph.add_output( - inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], - parent=def_node, - ) - - # Finally, load the function into the local data-flow graph - loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) - if len(captured) > 0: - loaded = graph.add_partial( - loaded, [dfg[v.name].port for v in captured], dfg.node - ).out_port(0) - - return PortVariable(func.name, loaded, func) + raise NotImplementedError diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py deleted file mode 100644 index db656b6f..00000000 --- a/guppy/compiler/stmt_compiler.py +++ /dev/null @@ -1,96 +0,0 @@ -import ast -from typing import Sequence - -from guppy.ast_util import AstVisitor -from guppy.checker.cfg_checker import CheckedBB -from guppy.compiler.core import ( - CompilerBase, - DFContainer, - CompiledGlobals, - PortVariable, - return_var, -) -from guppy.compiler.expr_compiler import ExprCompiler -from guppy.error import InternalGuppyError -from guppy.gtypes import TupleType -from guppy.hugr.hugr import Hugr, OutPortV -from guppy.nodes import CheckedNestedFunctionDef - - -class StmtCompiler(CompilerBase, AstVisitor[None]): - """A compiler for Guppy statements to Hugr""" - - expr_compiler: ExprCompiler - - bb: CheckedBB - dfg: DFContainer - - def __init__(self, graph: Hugr, globals: CompiledGlobals): - super().__init__(graph, globals) - self.expr_compiler = ExprCompiler(graph, globals) - - def compile_stmts( - self, - stmts: Sequence[ast.stmt], - bb: CheckedBB, - dfg: DFContainer, - ) -> DFContainer: - """Compiles a list of basic statements into a dataflow node. - - Note that the `dfg` is mutated in-place. After compilation, the DFG will also - contain all variables that are assigned in the given list of statements. - """ - self.bb = bb - self.dfg = dfg - for s in stmts: - self.visit(s) - return self.dfg - - def _unpack_assign(self, lhs: ast.expr, port: OutPortV, node: ast.stmt) -> None: - """Updates the local DFG with assignments.""" - if isinstance(lhs, ast.Name): - x = lhs.id - self.dfg[x] = PortVariable(x, port, node) - elif isinstance(lhs, ast.Tuple): - unpack = self.graph.add_unpack_tuple(port, self.dfg.node) - for i, pat in enumerate(lhs.elts): - self._unpack_assign(pat, unpack.out_port(i), node) - else: - raise InternalGuppyError("Invalid assign pattern in compiler") - - def visit_Assign(self, node: ast.Assign) -> None: - [target] = node.targets - port = self.expr_compiler.compile(node.value, self.dfg) - self._unpack_assign(target, port, node) - - def visit_AnnAssign(self, node: ast.AnnAssign) -> None: - assert node.value is not None - port = self.expr_compiler.compile(node.value, self.dfg) - self._unpack_assign(node.target, port, node) - - def visit_AugAssign(self, node: ast.AugAssign) -> None: - raise InternalGuppyError("Node should have been removed during type checking.") - - def visit_Expr(self, node: ast.Expr) -> None: - self.expr_compiler.compile_row(node.value, self.dfg) - - def visit_Return(self, node: ast.Return) -> None: - # We turn returns into assignments of dummy variables, i.e. the statement - # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`. - if node.value is not None: - port = self.expr_compiler.compile(node.value, self.dfg) - if isinstance(port.ty, TupleType): - unpack = self.graph.add_unpack_tuple(port, self.dfg.node) - row = [unpack.out_port(i) for i in range(len(port.ty.element_types))] - else: - row = [port] - for i, port in enumerate(row): - name = return_var(i) - self.dfg[name] = PortVariable(name, port, node) - - def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: - from guppy.compiler.func_compiler import compile_local_func_def - - self.dfg[node.name] = compile_local_func_def( - node, self.dfg, self.graph, self.globals - ) diff --git a/guppy/custom.py b/guppy/custom.py index acfd7341..4027232f 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -2,22 +2,19 @@ from abc import ABC, abstractmethod from typing import Optional -from guppy.ast_util import AstNode, get_type, with_type, with_loc +from guppy.ast_util import AstNode, with_type, with_loc from guppy.checker.core import Context, Globals -from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals -from guppy.compiler.expr_compiler import ExprCompiler from guppy.error import ( GuppyError, InternalGuppyError, UnknownFunctionType, GuppyTypeError, ) -from guppy.gtypes import GuppyType, FunctionType, type_to_row, TupleType +from guppy.gtypes import GuppyType, FunctionType from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode -from guppy.nodes import GlobalCall class CustomFunction(CompiledFunction): @@ -210,26 +207,17 @@ class DefaultCallChecker(CustomCallChecker): """Checks function calls by comparing to a type signature.""" def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - # Use default implementation from the expression checker - args = check_call(self.func.ty, args, ty, self.node, self.ctx) - return GlobalCall(func=self.func, args=args) + raise NotImplementedError def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - # Use default implementation from the expression checker - args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) - return GlobalCall(func=self.func, args=args), ty + raise NotImplementedError class DefaultCallCompiler(CustomCallCompiler): """Call compiler that invokes the regular expression compiler.""" def compile(self, args: list[OutPortV]) -> list[OutPortV]: - assert isinstance(self.node, ast.expr) - ret = ExprCompiler(self.graph, self.globals).compile(self.node, self.dfg) - if isinstance(ret.ty, TupleType): - unpack = self.graph.add_unpack_tuple(ret, self.dfg.node) - return [unpack.out_port(i) for i in range(len(ret.ty.element_types))] - return [ret] + raise NotImplementedError class OpCompiler(CustomCallCompiler): @@ -239,12 +227,4 @@ def __init__(self, op: ops.OpType) -> None: self.op = op def compile(self, args: list[OutPortV]) -> list[OutPortV]: - node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) - return_ty = get_type(self.node) - assert return_ty is not None - return [node.add_out_port(ty) for ty in type_to_row(return_ty)] - - -class NoopCompiler(CustomCallCompiler): - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - return args + raise NotImplementedError diff --git a/guppy/declared.py b/guppy/declared.py index 91b5283e..0106ed21 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -4,7 +4,6 @@ from guppy.ast_util import AstNode, has_empty_body from guppy.checker.core import Globals, Context -from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError @@ -33,16 +32,12 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context ) -> GlobalCall: - # Use default implementation from the expression checker - args = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args) + raise NotImplementedError def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: - # Use default implementation from the expression checker - args, ty = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args), ty + raise NotImplementedError def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py deleted file mode 100644 index 3fbb406f..00000000 --- a/guppy/prelude/_internal.py +++ /dev/null @@ -1,283 +0,0 @@ -import ast -from typing import Optional, Literal - -from pydantic import BaseModel - -from guppy.ast_util import with_type, AstNode, with_loc, get_type -from guppy.checker.core import Context, CallableVariable -from guppy.checker.expr_checker import ExprSynthesizer, check_num_args -from guppy.custom import ( - CustomCallChecker, - DefaultCallChecker, - CustomFunction, - CustomCallCompiler, -) -from guppy.error import GuppyTypeError, GuppyError -from guppy.gtypes import GuppyType, FunctionType, BoolType -from guppy.hugr import ops, tys, val -from guppy.hugr.hugr import OutPortV -from guppy.nodes import GlobalCall - - -INT_WIDTH = 6 # 2^6 = 64 bit - - -hugr_int_type = tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=INT_WIDTH)], - bound=tys.TypeBound.Eq, -) - - -hugr_float_type = tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=tys.TypeBound.Copyable, -) - - -class ConstIntS(BaseModel): - """Hugr representation of signed integers in the arithmetic extension.""" - - c: Literal["ConstIntS"] = "ConstIntS" - log_width: int - value: int - - -class ConstF64(BaseModel): - """Hugr representation of floats in the arithmetic extension.""" - - c: Literal["ConstF64"] = "ConstF64" - value: float - - -def bool_value(b: bool) -> val.Value: - """Returns the Hugr representation of a boolean value.""" - return val.Sum(tag=int(b), value=val.Tuple(vs=[])) - - -def int_value(i: int) -> val.Value: - """Returns the Hugr representation of an integer value.""" - return val.Prim(val=val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),))) - - -def float_value(f: float) -> val.Value: - """Returns the Hugr representation of a float value.""" - return val.Prim(val=val.ExtensionVal(c=(ConstF64(value=f),))) - - -def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops.OpType: - """Utility method to create Hugr logic ops.""" - return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) - - -def int_op( - op_name: str, ext: str = "arithmetic.int", num_params: int = 1 -) -> ops.OpType: - """Utility method to create Hugr integer arithmetic ops.""" - return ops.CustomOp( - extension=ext, - op_name=op_name, - args=num_params * [tys.BoundedNatArg(n=INT_WIDTH)], - ) - - -def float_op(op_name: str, ext: str = "arithmetic.float") -> ops.OpType: - """Utility method to create Hugr integer arithmetic ops.""" - return ops.CustomOp(extension=ext, op_name=op_name, args=[]) - - -class CoercingChecker(DefaultCallChecker): - """Function call type checker thag automatically coerces arguments to float.""" - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - from .builtins import Int - - for i in range(len(args)): - args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) - if isinstance(ty, self.ctx.globals.types["int"]): - call = with_loc( - self.node, GlobalCall(func=Int.__float__, args=[args[i]]) - ) - args[i] = with_type(self.ctx.globals.types["float"].build(), call) - return super().synthesize(args) - - -class ReversingChecker(CustomCallChecker): - """Call checker that reverses the arguments after checking.""" - - base_checker: CustomCallChecker - - def __init__(self, base_checker: Optional[CustomCallChecker] = None): - self.base_checker = base_checker or DefaultCallChecker() - - def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: - super()._setup(ctx, node, func) - self.base_checker._setup(ctx, node, func) - - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - expr = self.base_checker.check(args, ty) - if isinstance(expr, GlobalCall): - expr.args = list(reversed(args)) - return expr - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - expr, ty = self.base_checker.synthesize(args) - if isinstance(expr, GlobalCall): - expr.args = list(reversed(args)) - return expr, ty - - -class UnsupportedChecker(CustomCallChecker): - """Call checker for Python builtin functions that are not available in Guppy. - - Gives the uses a nicer error message when they try to use an unsupported feature. - """ - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - raise GuppyError( - f"Builtin method `{self.func.name}` is not supported by Guppy", self.node - ) - - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - raise GuppyError( - f"Builtin method `{self.func.name}` is not supported by Guppy", self.node - ) - - -class DunderChecker(CustomCallChecker): - """Call checker for builtin functions that call out to dunder instance methods""" - - dunder_name: str - num_args: int - - def __init__(self, dunder_name: str, num_args: int = 1): - assert num_args > 0 - self.dunder_name = dunder_name - self.num_args = num_args - - def _get_func( - self, args: list[ast.expr] - ) -> tuple[list[ast.expr], CallableVariable]: - check_num_args(self.num_args, len(args), self.node) - fst, *rest = args - fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) - func = self.ctx.globals.get_instance_func(ty, self.dunder_name) - if func is None: - raise GuppyTypeError( - f"Builtin function `{self.func.name}` is not defined for argument of " - f"type `{ty}`", - self.node.args[0] if isinstance(self.node, ast.Call) else self.node, - ) - return [fst, *rest], func - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - args, func = self._get_func(args) - return func.synthesize_call(args, self.node, self.ctx) - - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - args, func = self._get_func(args) - return func.check_call(args, ty, self.node, self.ctx) - - -class CallableChecker(CustomCallChecker): - """Call checker for the builtin `callable` function""" - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - check_num_args(1, len(args), self.node) - [arg] = args - arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) - is_callable = ( - isinstance(ty, FunctionType) - or self.ctx.globals.get_instance_func(ty, "__call__") is not None - ) - const = with_loc(self.node, ast.Constant(value=is_callable)) - return const, BoolType() - - -class IntTruedivCompiler(CustomCallCompiler): - """Compiler for the `int.__truediv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Int, Float - - # Compile `truediv` using float arithmetic - [left, right] = args - [left] = Int.__float__.compile_call( - [left], self.dfg, self.graph, self.globals, self.node - ) - [right] = Int.__float__.compile_call( - [right], self.dfg, self.graph, self.globals, self.node - ) - return Float.__truediv__.compile_call( - [left, right], self.dfg, self.graph, self.globals, self.node - ) - - -class FloatBoolCompiler(CustomCallCompiler): - """Compiler for the `float.__bool__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: bool(x) = (x != 0.0) - zero_const = self.graph.add_constant( - float_value(0.0), get_type(self.node), self.dfg.node - ) - zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) - return Float.__ne__.compile_call( - [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node - ) - - -class FloatFloordivCompiler(CustomCallCompiler): - """Compiler for the `float.__floordiv__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: floordiv(x, y) = floor(truediv(x, y)) - [div] = Float.__truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [floor] = Float.__floor__.compile_call( - [div], self.dfg, self.graph, self.globals, self.node - ) - return [floor] - - -class FloatModCompiler(CustomCallCompiler): - """Compiler for the `float.__mod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: mod(x, y) = x - (x // y) * y - [div] = Float.__floordiv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [mul] = Float.__mul__.compile_call( - [div, args[1]], self.dfg, self.graph, self.globals, self.node - ) - [sub] = Float.__sub__.compile_call( - [args[0], mul], self.dfg, self.graph, self.globals, self.node - ) - return [sub] - - -class FloatDivmodCompiler(CustomCallCompiler): - """Compiler for the `__divmod__` method.""" - - def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Float - - # We have: divmod(x, y) = (div(x, y), mod(x, y)) - [div] = Float.__truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - [mod] = Float.__mod__.compile_call( - args, self.dfg, self.graph, self.globals, self.node - ) - return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index 25dfc741..98358bc9 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -1,731 +1,6 @@ """Guppy module for builtin types and operations.""" -# mypy: disable-error-code="empty-body, misc, override, no-untyped-def" - -from guppy.custom import NoopCompiler, DefaultCallChecker -from guppy.decorator import guppy -from guppy.gtypes import BoolType -from guppy.hugr import tys, ops from guppy.module import GuppyModule -from guppy.prelude._internal import ( - logic_op, - int_op, - hugr_int_type, - hugr_float_type, - float_op, - CoercingChecker, - ReversingChecker, - IntTruedivCompiler, - FloatBoolCompiler, - FloatDivmodCompiler, - FloatFloordivCompiler, - FloatModCompiler, - DunderChecker, - CallableChecker, - UnsupportedChecker, -) builtins = GuppyModule("builtins", import_builtins=False) - - -@guppy.extend_type(builtins, BoolType) -class Bool: - @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) - def __and__(self: bool, other: bool) -> bool: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __bool__(self: bool) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("ifrombool")) - def __int__(self: bool) -> int: - ... - - @guppy.hugr_op(builtins, logic_op("Or", [tys.BoundedNatArg(n=2)])) - def __or__(self: bool, other: bool) -> bool: - ... - - -@guppy.type(builtins, hugr_int_type, name="int") -class Int: - @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) - def __abs__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("iadd")) - def __add__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("iand")) - def __and__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("itobool")) - def __bool__(self: int) -> bool: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __ceil__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2)) - def __divmod__(self: int, other: int) -> tuple[int, int]: - ... - - @guppy.hugr_op(builtins, int_op("ieq")) - def __eq__(self: int, other: int) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("convert_s", "arithmetic.conversions")) - def __float__(self: int) -> float: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __floor__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2)) - def __floordiv__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ige_s")) - def __ge__(self: int, other: int) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("igt_s")) - def __gt__(self: int, other: int) -> bool: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __int__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("inot")) - def __invert__(self: int) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("ile_s")) - def __le__(self: int, other: int) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("ishl", num_params=2)) # TODO: RHS is unsigned - def __lshift__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ilt_s")) - def __lt__(self: int, other: int) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("imod_s", num_params=2)) - def __mod__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("imul")) - def __mul__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ine")) - def __ne__(self: int, other: int) -> bool: - ... - - @guppy.hugr_op(builtins, int_op("ineg")) - def __neg__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ior")) - def __or__(self: int, other: int) -> int: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __pos__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, ops.DummyOp(name="ipow")) # TODO - def __pow__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker()) - def __radd__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker()) - def __rand__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker()) - def __rdivmod__(self: int, other: int) -> tuple[int, int]: - ... - - @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker()) - def __rfloordiv__(self: int, other: int) -> int: - ... - - @guppy.hugr_op( - builtins, int_op("ishl", num_params=2), ReversingChecker() - ) # TODO: RHS is unsigned - def __rlshift__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker()) - def __rmod__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker()) - def __rmul__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ior"), ReversingChecker()) - def __ror__(self: int, other: int) -> int: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __round__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, ops.DummyOp(name="ipow"), ReversingChecker()) # TODO - def __rpow__(self: int, other: int) -> int: - ... - - @guppy.hugr_op( - builtins, int_op("ishr", num_params=2), ReversingChecker() - ) # TODO: RHS is unsigned - def __rrshift__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ishr", num_params=2)) # TODO: RHS is unsigned - def __rshift__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker()) - def __rsub__(self: int, other: int) -> int: - ... - - @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker()) - def __rtruediv__(self: int, other: int) -> float: - ... - - @guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker()) - def __rxor__(self: int, other: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("isub")) - def __sub__(self: int, other: int) -> int: - ... - - @guppy.custom(builtins, IntTruedivCompiler()) - def __truediv__(self: int, other: int) -> float: - ... - - @guppy.custom(builtins, NoopCompiler()) - def __trunc__(self: int) -> int: - ... - - @guppy.hugr_op(builtins, int_op("ixor")) - def __xor__(self: int, other: int) -> int: - ... - - -@guppy.type(builtins, hugr_float_type, name="float") -class Float: - @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) - def __abs__(self: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fadd"), CoercingChecker()) - def __add__(self: float, other: float) -> float: - ... - - @guppy.custom(builtins, FloatBoolCompiler(), CoercingChecker()) - def __bool__(self: float) -> bool: - ... - - @guppy.hugr_op(builtins, float_op("fceil"), CoercingChecker()) - def __ceil__(self: float) -> float: - ... - - @guppy.custom(builtins, FloatDivmodCompiler(), CoercingChecker()) - def __divmod__(self: float, other: float) -> tuple[float, float]: - ... - - @guppy.hugr_op(builtins, float_op("feq"), CoercingChecker()) - def __eq__(self: float, other: float) -> bool: - ... - - @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) - def __float__(self: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("ffloor"), CoercingChecker()) - def __floor__(self: float) -> float: - ... - - @guppy.custom(builtins, FloatFloordivCompiler(), CoercingChecker()) - def __floordiv__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fge"), CoercingChecker()) - def __ge__(self: float, other: float) -> bool: - ... - - @guppy.hugr_op(builtins, float_op("fgt"), CoercingChecker()) - def __gt__(self: float, other: float) -> bool: - ... - - @guppy.hugr_op( - builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() - ) - def __int__(self: float) -> int: - ... - - @guppy.hugr_op(builtins, float_op("fle"), CoercingChecker()) - def __le__(self: float, other: float) -> bool: - ... - - @guppy.hugr_op(builtins, float_op("flt"), CoercingChecker()) - def __lt__(self: float, other: float) -> bool: - ... - - @guppy.custom(builtins, FloatModCompiler(), CoercingChecker()) - def __mod__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fmul"), CoercingChecker()) - def __mul__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fne"), CoercingChecker()) - def __ne__(self: float, other: float) -> bool: - ... - - @guppy.hugr_op(builtins, float_op("fneg"), CoercingChecker()) - def __neg__(self: float, other: float) -> float: - ... - - @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) - def __pos__(self: float) -> float: - ... - - @guppy.hugr_op(builtins, ops.DummyOp(name="fpow")) # TODO - def __pow__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fadd"), ReversingChecker(CoercingChecker())) - def __radd__(self: float, other: float) -> float: - ... - - @guppy.custom(builtins, FloatDivmodCompiler(), ReversingChecker(CoercingChecker())) - def __rdivmod__(self: float, other: float) -> tuple[float, float]: - ... - - @guppy.custom( - builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker()) - ) - def __rfloordiv__(self: float, other: float) -> float: - ... - - @guppy.custom(builtins, FloatModCompiler(), ReversingChecker(CoercingChecker())) - def __rmod__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fmul"), ReversingChecker(CoercingChecker())) - def __rmul__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, ops.DummyOp(name="fround")) # TODO - def __round__(self: float) -> float: - ... - - @guppy.hugr_op( - builtins, ops.DummyOp(name="fpow"), ReversingChecker(DefaultCallChecker()) - ) # TODO - def __rpow__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fsub"), ReversingChecker(CoercingChecker())) - def __rsub__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fdiv"), ReversingChecker(CoercingChecker())) - def __rtruediv__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fsub"), CoercingChecker()) - def __sub__(self: float, other: float) -> float: - ... - - @guppy.hugr_op(builtins, float_op("fdiv"), CoercingChecker()) - def __truediv__(self: float, other: float) -> float: - ... - - @guppy.hugr_op( - builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() - ) - def __trunc__(self: float) -> float: - ... - - -@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False) -def abs(x): - ... - - -@guppy.custom( - builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False -) -def _bool(x): - ... - - -@guppy.custom(builtins, checker=CallableChecker(), higher_order_value=False) -def callable(x): - ... - - -@guppy.custom( - builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False -) -def divmod(x, y): - ... - - -@guppy.custom( - builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False -) -def _float(x, y): - ... - - -@guppy.custom( - builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False -) -def _int(x): - ... - - -@guppy.custom( - builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False -) -def pow(x, y): - ... - - -@guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) -def round(x): - ... - - -# Python builtins that are not supported yet: - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def aiter(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def all(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def anext(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def any(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def bin(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def breakpoint(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def bytearray(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def bytes(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def chr(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def classmethod(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def compile(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def complex(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def delattr(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def dict(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def dir(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def enumerate(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def eval(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def exec(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def filter(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def format(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def forozenset(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def getattr(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def globals(): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def hasattr(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def hash(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def help(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def hex(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def id(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def input(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def isinstance(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def issubclass(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def iter(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def len(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def list(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def locals(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def map(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def max(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def memoryview(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def min(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def next(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def object(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def oct(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def open(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def ord(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def print(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def property(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def range(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def repr(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def reversed(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def set(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def setattr(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def slice(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def sorted(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def staticmethod(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def str(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def sum(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def super(x): - ... - - -@guppy.custom( - builtins, name="tuple", checker=UnsupportedChecker(), higher_order_value=False -) -def _tuple(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def type(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def vars(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def zip(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def __import__(x): - ... diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index 3b40af88..e1ef0f08 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -3,7 +3,7 @@ # mypy: disable-error-code=empty-body from guppy.decorator import guppy -from guppy.hugr import tys, ops +from guppy.hugr import tys from guppy.hugr.tys import TypeBound from guppy.module import GuppyModule @@ -11,11 +11,6 @@ quantum = GuppyModule("quantum") -def quantum_op(op_name: str) -> ops.OpType: - """Utility method to create Hugr quantum ops.""" - return ops.CustomOp(extension="quantum.tket2", op_name=op_name, args=[]) - - @guppy.type( quantum, tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), @@ -25,41 +20,13 @@ class Qubit: pass -@guppy.hugr_op(quantum, quantum_op("H")) def h(q: Qubit) -> Qubit: ... -@guppy.hugr_op(quantum, quantum_op("CX")) def cx(control: Qubit, target: Qubit) -> tuple[Qubit, Qubit]: ... -@guppy.hugr_op(quantum, quantum_op("RzF64")) -def rz(q: Qubit, angle: float) -> Qubit: - ... - - -@guppy.hugr_op(quantum, quantum_op("Measure")) def measure(q: Qubit) -> tuple[Qubit, bool]: ... - - -@guppy.hugr_op(quantum, quantum_op("T")) -def t(q: Qubit) -> Qubit: - ... - - -@guppy.hugr_op(quantum, quantum_op("Tdg")) -def tdg(q: Qubit) -> Qubit: - ... - - -@guppy.hugr_op(quantum, quantum_op("Z")) -def z(q: Qubit) -> Qubit: - ... - - -@guppy.hugr_op(quantum, quantum_op("X")) -def x(q: Qubit) -> Qubit: - ... From 871b14c2ba8ddce8a9435f972bcfc7cd56c20fa4 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 11:57:10 +0000 Subject: [PATCH 21/77] Add type checking code --- guppy/checker/cfg_checker.py | 157 +++++++++++++++- guppy/checker/expr_checker.py | 341 ++++++++++++++++++++++++++++++++++ guppy/checker/func_checker.py | 139 +++++++++++++- guppy/checker/stmt_checker.py | 126 +++++++++++++ guppy/custom.py | 10 +- guppy/declared.py | 9 +- 6 files changed, 766 insertions(+), 16 deletions(-) create mode 100644 guppy/checker/expr_checker.py create mode 100644 guppy/checker/stmt_checker.py diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 9cd0f41e..1fbb0521 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -1,11 +1,16 @@ +import collections from dataclasses import dataclass from typing import Sequence +from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Globals +from guppy.checker.core import Globals, Context from guppy.checker.core import Variable +from guppy.checker.expr_checker import ExprSynthesizer, to_bool +from guppy.checker.stmt_checker import StmtChecker +from guppy.error import GuppyError from guppy.gtypes import GuppyType @@ -52,7 +57,53 @@ def check_cfg( Annotates the basic blocks with input and output type signatures and removes unreachable blocks. """ - raise NotImplementedError + # First, we need to run program analysis + ass_before = set(v.name for v in inputs) + cfg.analyze(ass_before, ass_before) + + # We start by compiling the entry BB + checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) + checked_cfg.entry_bb = check_bb( + cfg.entry_bb, checked_cfg, inputs, return_ty, globals + ) + compiled = {cfg.entry_bb: checked_cfg.entry_bb} + + # Visit all control-flow edges in BFS order. We can't just do a normal loop over + # all BBs since the input types for a BB are computed by checking a predecessor. + # We do BFS instead of DFS to get a better error ordering. + queue = collections.deque( + (checked_cfg.entry_bb, i, succ) + for i, succ in enumerate(cfg.entry_bb.successors) + ) + while len(queue) > 0: + pred, num_output, bb = queue.popleft() + input_row = [ + Variable(v.name, v.ty, v.defined_at, None) + for v in pred.sig.output_rows[num_output] + ] + + if bb in compiled: + # If the BB was already compiled, we just have to check that the signatures + # match. + check_rows_match(input_row, compiled[bb].sig.input_row, bb) + else: + # Otherwise, check the BB and enqueue its successors + checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) + queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)] + compiled[bb] = checked_bb + + # Link up BBs in the checked CFG + compiled[bb].predecessors.append(pred) + pred.successors[num_output] = compiled[bb] + + checked_cfg.bbs = list(compiled.values()) + checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable + checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} + checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} + checked_cfg.maybe_ass_before = { + compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs + } + return checked_cfg def check_bb( @@ -62,4 +113,104 @@ def check_bb( return_ty: GuppyType, globals: Globals, ) -> CheckedBB: - raise NotImplementedError + cfg = bb.cfg + + # For the entry BB we have to separately check that all used variables are + # defined. For all other BBs, this will be checked when compiling a predecessor. + if len(bb.predecessors) == 0: + for x, use in bb.vars.used.items(): + if x not in cfg.ass_before[bb] and x not in globals.values: + raise GuppyError(f"Variable `{x}` is not defined", use) + + # Check the basic block + ctx = Context(globals, {v.name: v for v in inputs}) + checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) + + # If we branch, we also have to check the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + bb.branch_pred, ty = ExprSynthesizer(ctx).synthesize(bb.branch_pred) + bb.branch_pred, _ = to_bool(bb.branch_pred, ty, ctx) + + for succ in bb.successors: + for x, use_bb in cfg.live_before[succ].items(): + # Check that the variables requested by the successor are defined + if x not in ctx.locals and x not in ctx.globals.values: + # If the variable is defined on *some* paths, we can give a more + # informative error message + if x in cfg.maybe_ass_before[use_bb]: + # TODO: This should be "Variable x is not defined when coming + # from {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` is not defined on all control-flow paths.", + use_bb.vars.used[x], + ) + raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x]) + + # We have to check that used linear variables are not being outputted + if x in ctx.locals: + var = ctx.locals[x] + if var.ty.linear and var.used: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + cfg.live_before[succ][x].vars.used[x], + [var.used], + ) + + # On the other hand, unused linear variables *must* be outputted + for x, var in ctx.locals.items(): + if var.ty.linear and not var.used and x not in cfg.live_before[succ]: + # TODO: This should be "Variable x with linear type ty is not + # used in {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is " + "not used on all control-flow paths", + var.defined_at, + ) + + # Finally, we need to compute the signature of the basic block + outputs = [ + [ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals] + for succ in bb.successors + ] + + # Also prepare the successor list so we can fill it in later + checked_bb = CheckedBB( + bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) + ) + checked_bb.successors = [None] * len(bb.successors) # type: ignore + checked_bb.branch_pred = bb.branch_pred + return checked_bb + + +def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: + """Checks that the types of two rows match up. + + Otherwise, an error is thrown, alerting the user that a variable has different + types on different control-flow paths. + """ + map1, map2 = {v.name: v for v in row1}, {v.name: v for v in row2} + assert map1.keys() == map2.keys() + for x in map1: + v1, v2 = map1[x], map2[x] + if v1.ty != v2.ty: + # In the error message, we want to mention the variable that was first + # defined at the start. + if ( + v1.defined_at + and v2.defined_at + and line_col(v2.defined_at) < line_col(v1.defined_at) + ): + v1, v2 = v2, v1 + # We shouldn't mention temporary variables (starting with `%`) + # in error messages: + ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" + raise GuppyError( + f"{ident} can refer to different types: " + f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", + bb.cfg.live_before[bb][v1.name].vars.used[v1.name], + [v1.defined_at, v2.defined_at], + ) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py new file mode 100644 index 00000000..747d32c0 --- /dev/null +++ b/guppy/checker/expr_checker.py @@ -0,0 +1,341 @@ +import ast +from typing import Optional, Union, NoReturn, Any + +from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt +from guppy.checker.core import Context, CallableVariable, Globals +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType +from guppy.nodes import LocalName, GlobalName, LocalCall + +# Mapping from unary AST op to dunder method and display name +unary_table: dict[type[ast.unaryop], tuple[str, str]] = { + ast.UAdd: ("__pos__", "+"), + ast.USub: ("__neg__", "-"), + ast.Invert: ("__invert__", "~"), +} + +# Mapping from binary AST op to left dunder method, right dunder method and display name +AstOp = Union[ast.operator, ast.cmpop] +binary_table: dict[type[AstOp], tuple[str, str, str]] = { + ast.Add: ("__add__", "__radd__", "+"), + ast.Sub: ("__sub__", "__rsub__", "-"), + ast.Mult: ("__mul__", "__rmul__", "*"), + ast.Div: ("__truediv__", "__rtruediv__", "/"), + ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), + ast.Mod: ("__mod__", "__rmod__", "%"), + ast.Pow: ("__pow__", "__rpow__", "**"), + ast.LShift: ("__lshift__", "__rlshift__", "<<"), + ast.RShift: ("__rshift__", "__rrshift__", ">>"), + ast.BitOr: ("__or__", "__ror__", "|"), + ast.BitXor: ("__xor__", "__rxor__", "^"), + ast.BitAnd: ("__and__", "__rand__", "&"), + ast.MatMult: ("__matmul__", "__rmatmul__", "@"), + ast.Eq: ("__eq__", "__eq__", "=="), + ast.NotEq: ("__neq__", "__neq__", "!="), + ast.Lt: ("__lt__", "__gt__", "<"), + ast.LtE: ("__le__", "__ge__", "<="), + ast.Gt: ("__gt__", "__lt__", ">"), + ast.GtE: ("__ge__", "__le__", ">="), +} + + +class ExprChecker(AstVisitor[ast.expr]): + """Checks an expression against a type and produces a new type-annotated AST""" + + ctx: Context + + # Name for the kind of term we are currently checking against (used in errors). + # For example, "argument", "return value", or in general "expression". + _kind: str + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + self._kind = "expression" + + def _fail( + self, + expected: GuppyType, + actual: Union[ast.expr, GuppyType], + loc: Optional[AstNode] = None, + ) -> NoReturn: + """Raises a type error indicating that the type doesn't match.""" + if not isinstance(actual, GuppyType): + loc = loc or actual + _, actual = self._synthesize(actual) + assert loc is not None + raise GuppyTypeError( + f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc + ) + + def check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + """Checks an expression against a type. + + Returns a new desugared expression with type annotations. + """ + old_kind = self._kind + self._kind = kind or self._kind + expr = self.visit(expr, ty) + self._kind = old_kind + return with_type(ty, expr) + + def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Invokes the type synthesiser""" + return ExprSynthesizer(self.ctx).synthesize(node) + + def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: + if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): + return self._fail(ty, node) + for i, el in enumerate(node.elts): + node.elts[i] = self.check(el, ty.element_types[i]) + return node + + def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore + # Try to synthesize and then check if it matches the given type + node, synth = self._synthesize(node) + if synth != ty: + self._fail(ty, synth, node) + return node + + +class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): + ctx: Context + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + + def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Tries to synthesise a type for the given expression. + + Also returns a new desugared expression with type annotations. + """ + if ty := get_type_opt(node): + return node, ty + node, ty = self.visit(node) + return with_type(ty, node), ty + + def _check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + """Checks an expression against a given type""" + return ExprChecker(self.ctx).check(expr, ty, kind) + + def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: + ty = python_value_to_guppy_type(node.value, node, self.ctx.globals) + if ty is None: + raise GuppyError("Unsupported constant", node) + return node, ty + + def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + x = node.id + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is not None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + node, + [var.used], + ) + var.used = node + return with_loc(node, LocalName(id=x)), var.ty + elif x in self.ctx.globals.values: + # Cache value in the AST + value = self.ctx.globals.values[x] + return with_loc(node, GlobalName(id=x, value=value)), value.ty + raise InternalGuppyError( + f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " + f"been caught by program analysis!" + ) + + def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: + elems = [self.synthesize(elem) for elem in node.elts] + node.elts = [n for n, _ in elems] + return node, TupleType([ty for _, ty in elems]) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: + # We need to synthesise the argument type, so we can look up dunder methods + node.operand, op_ty = self.synthesize(node.operand) + + # Special case for the `not` operation since it is not implemented via a dunder + # method or control-flow + if isinstance(node.op, ast.Not): + node.operand, bool_ty = to_bool(node.operand, op_ty, self.ctx) + return node, bool_ty + + # Check all other unary expressions by calling out to instance dunder methods + op, display_name = unary_table[node.op.__class__] + func = self.ctx.globals.get_instance_func(op_ty, op) + if func is None: + raise GuppyTypeError( + f"Unary operator `{display_name}` not defined for argument of type " + f" `{op_ty}`", + node.operand, + ) + return func.synthesize_call([node.operand], node, self.ctx) + + def _synthesize_binary( + self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr + ) -> tuple[ast.expr, GuppyType]: + """Helper method to compile binary operators by calling out to dunder methods. + + For example, first try calling `__add__` on the left operand. If that fails, try + `__radd__` on the right operand. + """ + if op.__class__ not in binary_table: + raise GuppyError("This binary operation is not supported by Guppy.", op) + lop, rop, display_name = binary_table[op.__class__] + left_expr, left_ty = self.synthesize(left_expr) + right_expr, right_ty = self.synthesize(right_expr) + + if func := self.ctx.globals.get_instance_func(left_ty, lop): + try: + return func.synthesize_call([left_expr, right_expr], node, self.ctx) + except GuppyError: + pass + + if func := self.ctx.globals.get_instance_func(right_ty, rop): + try: + return func.synthesize_call([right_expr, left_expr], node, self.ctx) + except GuppyError: + pass + + raise GuppyTypeError( + f"Binary operator `{display_name}` not defined for arguments of type " + f"`{left_ty}` and `{right_ty}`", + node, + ) + + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: + return self._synthesize_binary(node.left, node.right, node.op, node) + + def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: + if len(node.comparators) != 1 or len(node.ops) != 1: + raise InternalGuppyError( + "BB contains chained comparison. Should have been removed during CFG " + "construction." + ) + left_expr, [op], [right_expr] = node.left, node.ops, node.comparators + return self._synthesize_binary(left_expr, right_expr, op, node) + + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: + if len(node.keywords) > 0: + raise GuppyError( + "Argument passing by keyword is not supported", node.keywords[0] + ) + node.func, ty = self.synthesize(node.func) + + # First handle direct calls of user-defined functions and extension functions + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): + return node.func.value.synthesize_call(node.args, node, self.ctx) + + # Otherwise, it must be a function as a higher-order value + if isinstance(ty, FunctionType): + args, return_ty = synthesize_call(ty, node.args, node, self.ctx) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif f := self.ctx.globals.get_instance_func(ty, "__call__"): + return f.synthesize_call(node.args, node, self.ctx) + else: + raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `NamedExpr`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + + def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `BoolOp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `IfExp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + +def check_num_args(exp: int, act: int, node: AstNode) -> None: + """Checks that the correct number of arguments have been passed to a function.""" + if act < exp: + raise GuppyTypeError( + f"Not enough arguments passed (expected {exp}, got {act})", node + ) + if exp < act: + if isinstance(node, ast.Call): + raise GuppyTypeError("Unexpected argument", node.args[exp]) + raise GuppyTypeError( + f"Too many arguments passed (expected {exp}, got {act})", node + ) + + +def synthesize_call( + func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context +) -> tuple[list[ast.expr], GuppyType]: + """Synthesizes the return type of a function call""" + check_num_args(len(func_ty.args), len(args), node) + for i, arg in enumerate(args): + args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") + return args, func_ty.returns + + +def check_call( + func_ty: FunctionType, + args: list[ast.expr], + ty: GuppyType, + node: AstNode, + ctx: Context, +) -> list[ast.expr]: + """Checks the return type of a function call against a given type""" + args, return_ty = synthesize_call(func_ty, args, node, ctx) + if return_ty != ty: + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `{return_ty}`", node + ) + return args + + +def to_bool( + node: ast.expr, node_ty: GuppyType, ctx: Context +) -> tuple[ast.expr, GuppyType]: + """Tries to turn a node into a bool""" + if isinstance(node_ty, BoolType): + return node, node_ty + + func = ctx.globals.get_instance_func(node_ty, "__bool__") + if func is None: + raise GuppyTypeError( + f"Expression of type `{node_ty}` cannot be interpreted as a `bool`", + node, + ) + + # We could check the return type against bool, but we can give a better error + # message if we synthesise and compare to bool by hand + call, return_ty = func.synthesize_call([node], node, ctx) + if not isinstance(return_ty, BoolType): + raise GuppyTypeError( + f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", + node, + ) + return call, return_ty + + +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals +) -> Optional[GuppyType]: + """Turns a primitive Python value into a Guppy type. + + Returns `None` if the Python value cannot be represented in Guppy. + """ + if isinstance(v, bool): + return globals.types["bool"].build(node=node) + elif isinstance(v, int): + return globals.types["int"].build(node=node) + if isinstance(v, float): + return globals.types["float"].build(node=node) + return None diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index cb2d3234..87b22b8f 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -1,11 +1,14 @@ import ast from dataclasses import dataclass -from guppy.ast_util import AstNode +from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc from guppy.cfg.bb import BB -from guppy.checker.core import Globals, Context, CallableVariable -from guppy.checker.cfg_checker import CheckedCFG -from guppy.gtypes import FunctionType, GuppyType +from guppy.cfg.builder import CFGBuilder +from guppy.checker.core import Variable, Globals, Context, CallableVariable +from guppy.checker.cfg_checker import check_cfg, CheckedCFG +from guppy.checker.expr_checker import synthesize_call, check_call +from guppy.error import GuppyError +from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef @@ -26,12 +29,16 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context ) -> GlobalCall: - raise NotImplementedError + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: - raise NotImplementedError + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty @dataclass @@ -43,17 +50,131 @@ class CheckedFunction(DefinedFunction): def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFunction: """Type checks a top-level function definition.""" - raise NotImplementedError + func_def = func.defined_at + args = func_def.args.args + returns_none = isinstance(func.ty.returns, NoneType) + assert func.ty.arg_names is not None + + cfg = CFGBuilder().build(func_def.body, returns_none, globals) + inputs = [ + Variable(x, ty, loc, None) + for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) + ] + cfg = check_cfg(cfg, inputs, func.ty.returns, globals) + return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) def check_nested_func_def( func_def: NestedFunctionDef, bb: BB, ctx: Context ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" - raise NotImplementedError + func_ty = check_signature(func_def, ctx.globals) + assert func_ty.arg_names is not None + + # We've already built the CFG for this function while building the CFG of the + # enclosing function + cfg = func_def.cfg + + # Find captured variables + parent_cfg = bb.cfg + def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() + maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] + cfg.analyze(def_ass_before, maybe_ass_before) + captured = { + x: ctx.locals[x] + for x in cfg.live_before[cfg.entry_bb] + if x not in func_ty.arg_names and x in ctx.locals + } + + # Captured variables may not be linear + for v in captured.values(): + if v.ty.linear: + x = v.name + using_bb = cfg.live_before[cfg.entry_bb][x] + raise GuppyError( + f"Variable `{x}` with linear type `{v.ty}` may not be used here " + f"because it was defined in an outer scope (at {{0}})", + using_bb.vars.used[x], + [v.defined_at], + ) + + # Captured variables may never be assigned to + for bb in cfg.bbs: + for v in captured.values(): + x = v.name + if x in bb.vars.assigned: + raise GuppyError( + f"Variable `{x}` defined in an outer scope (at {{0}}) may not " + f"be assigned to", + bb.vars.assigned[x], + [v.defined_at], + ) + + # Construct inputs for checking the body CFG + inputs = list(captured.values()) + [ + Variable(x, ty, func_def.args.args[i], None) + for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) + ] + globals = ctx.globals + + # Check if the body contains a recursive occurrence of the function name + if func_def.name in cfg.live_before[cfg.entry_bb]: + if len(captured) == 0: + # If there are no captured vars, we treat the function like a global name + func = DefinedFunction(func_def.name, func_ty, func_def, None) + globals = ctx.globals | Globals({func_def.name: func}, {}) + + else: + # Otherwise, we treat it like a local name + inputs.append(Variable(func_def.name, func_def.ty, func_def, None)) + + checked_cfg = check_cfg(cfg, inputs, func_ty.returns, globals) + checked_def = CheckedNestedFunctionDef( + checked_cfg, + func_ty, + captured, + name=func_def.name, + args=func_def.args, + body=func_def.body, + decorator_list=func_def.decorator_list, + returns=func_def.returns, + type_comment=func_def.type_comment, + ) + return with_loc(func_def, checked_def) def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: """Checks the signature of a function definition and returns the corresponding Guppy type.""" - raise NotImplementedError + if len(func_def.args.posonlyargs) != 0: + raise GuppyError( + "Positional-only parameters not supported", func_def.args.posonlyargs[0] + ) + if len(func_def.args.kwonlyargs) != 0: + raise GuppyError( + "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] + ) + if func_def.args.vararg is not None: + raise GuppyError("*args not supported", func_def.args.vararg) + if func_def.args.kwarg is not None: + raise GuppyError("**kwargs not supported", func_def.args.kwarg) + if func_def.returns is None: + # TODO: Error location is incorrect + if all(r.value is None for r in return_nodes_in_ast(func_def)): + raise GuppyError( + "Return type must be annotated. Try adding a `-> None` annotation.", + func_def, + ) + raise GuppyError("Return type must be annotated", func_def) + + arg_tys = [] + arg_names = [] + for i, arg in enumerate(func_def.args.args): + if arg.annotation is None: + raise GuppyError("Argument type must be annotated", arg) + ty = type_from_ast(arg.annotation, globals) + arg_tys.append(ty) + arg_names.append(arg.arg) + + ret_type = type_from_ast(func_def.returns, globals) + return FunctionType(arg_tys, ret_type, arg_names) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py new file mode 100644 index 00000000..5b12dc58 --- /dev/null +++ b/guppy/checker/stmt_checker.py @@ -0,0 +1,126 @@ +import ast +from typing import Sequence + +from guppy.ast_util import with_loc, AstVisitor +from guppy.cfg.bb import BB, BBStatement +from guppy.checker.core import Variable, Context +from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType +from guppy.nodes import NestedFunctionDef + + +class StmtChecker(AstVisitor[BBStatement]): + ctx: Context + bb: BB + return_ty: GuppyType + + def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: + self.ctx = ctx + self.bb = bb + self.return_ty = return_ty + + def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: + return [self.visit(s) for s in stmts] + + def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + return ExprSynthesizer(self.ctx).synthesize(node) + + def _check_expr( + self, node: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + return ExprChecker(self.ctx).check(node, ty, kind) + + def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: + """Helper function to check assignments with patterns.""" + # Easiest case is if the LHS pattern is a single variable. + if isinstance(lhs, ast.Name): + # Check if we override an unused linear variable + x = lhs.id + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + self.ctx.locals[x] = Variable(x, ty, node, None) + # The only other thing we support right now are tuples + elif isinstance(lhs, ast.Tuple): + tys = ty.element_types if isinstance(ty, TupleType) else [ty] + n, m = len(lhs.elts), len(tys) + if n != m: + raise GuppyTypeError( + f"{'Too many' if n < m else 'Not enough'} values to unpack " + f"(expected {n}, got {m})", + node, + ) + for pat, el_ty in zip(lhs.elts, tys): + self._check_assign(pat, el_ty, node) + # TODO: Python also supports assignments like `[a, b] = [1, 2]` or + # `a, *b = ...`. The former would require some runtime checks but + # the latter should be easier to do (unpack and repack the rest). + else: + raise GuppyError("Assignment pattern not supported", lhs) + + def visit_Assign(self, node: ast.Assign) -> ast.stmt: + if len(node.targets) > 1: + # This is the case for assignments like `a = b = 1` + raise GuppyError("Multi assignment not supported", node) + + [target] = node.targets + node.value, ty = self._synth_expr(node.value) + self._check_assign(target, ty, node) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: + if node.value is None: + raise GuppyError( + "Variable declaration is not supported. Assignment is required", node + ) + ty = type_from_ast(node.annotation, self.ctx.globals) + node.value = self._check_expr(node.value, ty) + self._check_assign(node.target, ty, node) + return node + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: + bin_op = with_loc( + node, ast.BinOp(left=node.target, op=node.op, right=node.value) + ) + assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) + return self.visit_Assign(assign) + + def visit_Expr(self, node: ast.Expr) -> ast.stmt: + # An expression statement where the return value is discarded + node.value, _ = self._synth_expr(node.value) + return node + + def visit_Return(self, node: ast.Return) -> ast.stmt: + if node.value is not None: + node.value = self._check_expr(node.value, self.return_ty, "return value") + elif not isinstance(self.return_ty, NoneType): + raise GuppyTypeError( + f"Expected return value of type `{self.return_ty}`", None + ) + return node + + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: + from guppy.checker.func_checker import check_nested_func_def + + func_def = check_nested_func_def(node, self.bb, self.ctx) + self.ctx.locals[func_def.name] = Variable( + func_def.name, func_def.ty, func_def, None + ) + return func_def + + def visit_If(self, node: ast.If) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_While(self, node: ast.While) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Break(self, node: ast.Break) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Continue(self, node: ast.Continue) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") diff --git a/guppy/custom.py b/guppy/custom.py index 4027232f..55a5896e 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -4,6 +4,7 @@ from guppy.ast_util import AstNode, with_type, with_loc from guppy.checker.core import Context, Globals +from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import ( @@ -15,6 +16,7 @@ from guppy.gtypes import GuppyType, FunctionType from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode +from guppy.nodes import GlobalCall class CustomFunction(CompiledFunction): @@ -207,10 +209,14 @@ class DefaultCallChecker(CustomCallChecker): """Checks function calls by comparing to a type signature.""" def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - raise NotImplementedError + # Use default implementation from the expression checker + args = check_call(self.func.ty, args, ty, self.node, self.ctx) + return GlobalCall(func=self.func, args=args) def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - raise NotImplementedError + # Use default implementation from the expression checker + args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) + return GlobalCall(func=self.func, args=args), ty class DefaultCallCompiler(CustomCallCompiler): diff --git a/guppy/declared.py b/guppy/declared.py index 0106ed21..91b5283e 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -4,6 +4,7 @@ from guppy.ast_util import AstNode, has_empty_body from guppy.checker.core import Globals, Context +from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError @@ -32,12 +33,16 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context ) -> GlobalCall: - raise NotImplementedError + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: - raise NotImplementedError + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) From aa3385796f539cd55223bbf874a081dcb5c04f0f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 12:10:39 +0000 Subject: [PATCH 22/77] Add _internal file --- guppy/prelude/_internal.py | 39 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 guppy/prelude/_internal.py diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py new file mode 100644 index 00000000..ee31d46e --- /dev/null +++ b/guppy/prelude/_internal.py @@ -0,0 +1,39 @@ +from typing import Literal + +from pydantic import BaseModel + +from guppy.hugr import val + + +INT_WIDTH = 6 # 2^6 = 64 bit + + +class ConstIntS(BaseModel): + """Hugr representation of signed integers in the arithmetic extension.""" + + c: Literal["ConstIntS"] = "ConstIntS" + log_width: int + value: int + + +class ConstF64(BaseModel): + """Hugr representation of floats in the arithmetic extension.""" + + c: Literal["ConstF64"] = "ConstF64" + value: float + + +def bool_value(b: bool) -> val.Value: + """Returns the Hugr representation of a boolean value.""" + return val.Sum(tag=int(b), value=val.Tuple(vs=[])) + + +def int_value(i: int) -> val.Value: + """Returns the Hugr representation of an integer value.""" + return val.Prim(val=val.ExtensionVal(c=(ConstIntS(log_width=INT_WIDTH, value=i),))) + + +def float_value(f: float) -> val.Value: + """Returns the Hugr representation of a float value.""" + return val.Prim(val=val.ExtensionVal(c=(ConstF64(value=f),))) + From aabc6ca57ad8d42f39fc27f608abfcfbe6abaaa5 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 22 Nov 2023 12:11:08 +0000 Subject: [PATCH 23/77] Fix formatting --- guppy/prelude/_internal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index ee31d46e..66b40574 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -36,4 +36,3 @@ def int_value(i: int) -> val.Value: def float_value(f: float) -> val.Value: """Returns the Hugr representation of a float value.""" return val.Prim(val=val.ExtensionVal(c=(ConstF64(value=f),))) - From 21585fa60060a4fe516db20d035eb201a25a3181 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 24 Nov 2023 14:34:09 +0000 Subject: [PATCH 24/77] Work on polymorphism --- guppy/checker/core.py | 17 +- guppy/checker/expr_checker.py | 220 ++++++++++++++++++---- guppy/checker/func_checker.py | 27 ++- guppy/checker/stmt_checker.py | 11 +- guppy/custom.py | 24 ++- guppy/declared.py | 12 +- guppy/decorator.py | 24 ++- guppy/error.py | 4 +- guppy/gtypes.py | 344 ++++++++++++++++++++++++++++++---- guppy/module.py | 25 ++- guppy/nodes.py | 18 +- 11 files changed, 614 insertions(+), 112 deletions(-) diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 9bab6375..6c9c8957 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -10,7 +10,7 @@ TupleType, SumType, NoneType, - BoolType, + BoolType, Subst, ) @@ -33,7 +33,7 @@ class CallableVariable(ABC, Variable): @abstractmethod def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: """Checks the return type of a function call against a given type.""" @abstractmethod @@ -43,6 +43,14 @@ def synthesize_call( """Synthesizes the return type of a function call.""" +@dataclass +class TypeVarDecl: + """A declared type variable.""" + + name: str + linear: bool + + class Globals(NamedTuple): """Collection of names that are available on module-level. @@ -52,6 +60,7 @@ class Globals(NamedTuple): values: dict[str, Variable] types: dict[str, type[GuppyType]] + type_vars: dict[str, TypeVarDecl] @staticmethod def default() -> "Globals": @@ -63,7 +72,7 @@ def default() -> "Globals": NoneType.name: NoneType, BoolType.name: BoolType, } - return Globals({}, tys) + return Globals({}, tys, {}) def get_instance_func(self, ty: GuppyType, name: str) -> Optional[CallableVariable]: """Looks up an instance function with a given name for a type. @@ -81,11 +90,13 @@ def __or__(self, other: "Globals") -> "Globals": return Globals( self.values | other.values, self.types | other.types, + self.type_vars | other.type_vars ) def __ior__(self, other: "Globals") -> "Globals": self.values.update(other.values) self.types.update(other.types) + self.type_vars.update(other.type_vars) return self diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 747d32c0..350a0ba6 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -4,8 +4,9 @@ from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType -from guppy.nodes import LocalName, GlobalName, LocalCall +from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType, Subst, \ + FreeTypeVar, unify, Inst +from guppy.nodes import LocalName, GlobalName, LocalCall, TypeApply # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -39,7 +40,7 @@ } -class ExprChecker(AstVisitor[ast.expr]): +class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]): """Checks an expression against a type and produces a new type-annotated AST""" ctx: Context @@ -61,7 +62,7 @@ def _fail( """Raises a type error indicating that the type doesn't match.""" if not isinstance(actual, GuppyType): loc = loc or actual - _, actual = self._synthesize(actual) + _, actual = self._synthesize(actual, allow_free_vars=True) assert loc is not None raise GuppyTypeError( f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc @@ -69,34 +70,106 @@ def _fail( def check( self, expr: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: """Checks an expression against a type. - Returns a new desugared expression with type annotations. + The type may have free type variables which will try to be resolved. Returns + a new desugared expression with type annotations and a substitution with the + resolved type variables. """ + # When checking against a variable, we have to synthesize + if isinstance(ty, FreeTypeVar): + expr, syn_ty = self._synthesize(expr, allow_free_vars=False) + return with_type(syn_ty, expr), {ty.id: syn_ty} + + # Otherwise, invoke the visitor old_kind = self._kind self._kind = kind or self._kind - expr = self.visit(expr, ty) + expr, subst = self.visit(expr, ty) self._kind = old_kind - return with_type(ty, expr) + return with_type(ty.substitute(subst), expr), subst - def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + def _synthesize(self, node: ast.expr, allow_free_vars: bool) -> tuple[ast.expr, GuppyType]: """Invokes the type synthesiser""" - return ExprSynthesizer(self.ctx).synthesize(node) + return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) - def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: + def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): return self._fail(ty, node) + subst: Subst = {} for i, el in enumerate(node.elts): - node.elts[i] = self.check(el, ty.element_types[i]) - return node + node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) + subst |= s + return node, subst + + def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: + if len(node.keywords) > 0: + raise GuppyError( + "Argument passing by keyword is not supported", node.keywords[0] + ) + node.func, func_ty = self._synthesize(node.func, allow_free_vars=False) - def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore - # Try to synthesize and then check if it matches the given type - node, synth = self._synthesize(node) - if synth != ty: + # First handle direct calls of user-defined functions and extension functions + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): + return node.func.value.check_call(node.args, ty, node, self.ctx) + + # Otherwise, it must be a function as a higher-order value + if isinstance(func_ty, FunctionType): + args, return_ty, inst = check_call(func_ty, node.args, ty, node, self.ctx) + # Maybe we have to add a TypeApply node + if len(inst) > 0: + func = TypeApply(value=with_type(ty, node.func), tys=inst) + func = with_type(func_ty.instantiate(inst), func) + node.func = with_loc(node.func, func) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif f := self.ctx.globals.get_instance_func(func_ty, "__call__"): + return f.check_call(node.args, ty, node, self.ctx) + else: + raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) + + def generic_visit( # type: ignore + self, node: ast.expr, ty: GuppyType + ) -> tuple[ast.expr, Subst]: + # Try to synthesize and then check if we can unify it with the given type + node, synth = self._synthesize(node, allow_free_vars=False) + + # Special case if we synthesized a polymorphic function type. In that case, we + # have to find an instantiation to avoid higher-rank types. + subst: Optional[Subst] + if isinstance(synth, FunctionType) and synth.quantified: + unquantified, free_vars = synth.unquantified() + subst = unify(ty, unquantified, {}) + if subst is None: + self._fail(ty, synth, node) + # Check that we have found a valid instantiation for all quantified vars + for i, v in enumerate(free_vars): + if v.id not in subst: + raise GuppyTypeError( + f"Expected {self._kind} of type `{ty}`, got `{synth}`. " + "Couldn't infer an instantiation for type variable " + f"`{synth.quantified[i]}`", + node + ) + if subst[v.id].free_vars: + raise GuppyTypeError( + f"Expected {self._kind} of type `{ty}`, got `{synth}`. Can't " + f"instantiate type variable `{synth.quantified[i]}` with type " + f"`{subst[v.id]}` containing free variables", + node + ) + inst = [subst[v.id] for v in free_vars] + node = with_loc(node, TypeApply(value=node, tys=inst)) + subst = {v: t for v, t in subst.items() if v in ty.free_vars} + return node, subst + + # Otherwise, we know that `synth` has no free type vars, so unification is + # trivial + subst = unify(ty, synth, {}) + if subst is None: self._fail(ty, synth, node) - return node + return node, subst class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): @@ -105,7 +178,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): def __init__(self, ctx: Context) -> None: self.ctx = ctx - def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + def synthesize(self, node: ast.expr, allow_free_vars: bool = False) -> tuple[ast.expr, GuppyType]: """Tries to synthesise a type for the given expression. Also returns a new desugared expression with type annotations. @@ -113,11 +186,16 @@ def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: if ty := get_type_opt(node): return node, ty node, ty = self.visit(node) + if ty.free_vars and not allow_free_vars: + raise GuppyTypeError( + f"Cannot infer type variable in expression of type `{ty}`", + node + ) return with_type(ty, node), ty def _check( self, expr: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: """Checks an expression against a given type""" return ExprChecker(self.ctx).check(expr, ty, kind) @@ -234,7 +312,12 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: # Otherwise, it must be a function as a higher-order value if isinstance(ty, FunctionType): - args, return_ty = synthesize_call(ty, node.args, node, self.ctx) + args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) + # Maybe we have to add a TypeApply node + if len(inst) > 0: + func = TypeApply(value=with_type(ty, node.func), tys=inst) + func = with_type(ty.instantiate(inst), func) + node.func = with_loc(node.func, func) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) @@ -274,14 +357,59 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: ) +def type_check_args(args: list[ast.expr], func_ty: FunctionType, subst: Subst, ctx: Context, node: AstNode) -> tuple[list[ast.expr], Subst]: + """Checks the arguments of a function call and infers free type variables. + + We expect that quantified variables have been replaced with free unification + variables. Checks that all unification variables can be inferred. + """ + assert not func_ty.quantified + check_num_args(len(func_ty.args), len(args), node) + + new_args: list[ast.expr] = [] + for arg, ty in zip(args, func_ty.args): + a, s = ExprChecker(ctx).check(arg, ty.substitute(subst), "argument") + new_args.append(a) + subst |= s + + # If the argument check succeeded, this means that we must have found instantiations + # for all unification variables occurring in the argument types + assert all( + set.issubset(set(arg.free_vars.keys()), subst.keys()) + for arg in func_ty.args + ) + + # We just have to check that we found instantiations for all vars in the return type + if not set.issubset(set(func_ty.returns.free_vars.keys()), subst.keys()): + raise GuppyTypeError( + f"Cannot infer type variable in expression of type " + f"`{func_ty.returns.substitute(subst)}`", + node + ) + + return new_args, subst + + def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context -) -> tuple[list[ast.expr], GuppyType]: - """Synthesizes the return type of a function call""" +) -> tuple[list[ast.expr], GuppyType, Inst]: + """Synthesizes the return type of a function call. + + Returns an annotated argument list, the synthesized return type, and an + instantiation for the quantifiers in the function type. + """ + assert not func_ty.free_vars check_num_args(len(func_ty.args), len(args), node) - for i, arg in enumerate(args): - args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") - return args, func_ty.returns + + # Replace quantified variables with free unification variables and try to infer an + # instantiation by checking the arguments + unquantified, free_vars = func_ty.unquantified() + args, subst = type_check_args(args, unquantified, {}, ctx, node) + + # Success implies that the substitution is closed + assert all(not t.free_vars for t in subst.values()) + inst = [subst[v.id] for v in free_vars] + return args, unquantified.returns.substitute(subst), inst def check_call( @@ -290,14 +418,42 @@ def check_call( ty: GuppyType, node: AstNode, ctx: Context, -) -> list[ast.expr]: - """Checks the return type of a function call against a given type""" - args, return_ty = synthesize_call(func_ty, args, node, ctx) - if return_ty != ty: + kind: str = "expression" +) -> tuple[list[ast.expr], Subst, Inst]: + """Checks the return type of a function call against a given type. + + Returns an annotated argument list, a substitution for the free variables in the + expected type, and an instantiation for the quantifiers in the function type. + """ + assert not func_ty.free_vars + check_num_args(len(func_ty.args), len(args), node) + + # Replace quantified variables with free unification variables and try to do some + # inference based on the expected return type + unquantified, free_vars = func_ty.unquantified() + subst = unify(ty, unquantified.returns, {}) + if subst is None: raise GuppyTypeError( - f"Expected expression of type `{ty}`, got `{return_ty}`", node + f"Expected {kind} of type `{ty}`, got `{unquantified.returns}`", node ) - return args + + # Try to infer more by checking against the arguments + args, subst = type_check_args(args, unquantified, subst, ctx, node) + + # Also make sure we found an instantiation for all free vars in the type we're + # checking against + if not set.issubset(set(ty.free_vars.keys()), subst.keys()): + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got " + f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", + node + ) + + # Success implies that the substitution is closed + assert all(not t.free_vars for t in subst.values()) + inst = [subst[v.id] for v in free_vars] + subst = {v: t for v, t in subst.items() if v in ty.free_vars} + return args, subst, inst def to_bool( diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 87b22b8f..5fd47c4d 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -8,7 +8,8 @@ from guppy.checker.cfg_checker import check_cfg, CheckedCFG from guppy.checker.expr_checker import synthesize_call, check_call from guppy.error import GuppyError -from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType +from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType, Subst, \ + BoundTypeVar from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef @@ -24,21 +25,25 @@ def from_ast( func_def: ast.FunctionDef, name: str, globals: Globals ) -> "DefinedFunction": ty = check_signature(func_def, globals) + if ty.quantified: + raise GuppyError( + "Generic function definitions are not supported yet", func_def + ) return DefinedFunction(name, ty, func_def, None) def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context - ) -> GlobalCall: + ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker - args = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args) + args, subst, inst = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker - args, ty = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args), ty + args, ty, inst = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), ty @dataclass @@ -122,7 +127,7 @@ def check_nested_func_def( if len(captured) == 0: # If there are no captured vars, we treat the function like a global name func = DefinedFunction(func_def.name, func_ty, func_def, None) - globals = ctx.globals | Globals({func_def.name: func}, {}) + globals = ctx.globals | Globals({func_def.name: func}, {}, {}) else: # Otherwise, we treat it like a local name @@ -167,14 +172,16 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType ) raise GuppyError("Return type must be annotated", func_def) + # TODO: Prepopulate mapping when using Python 3.12 style generic functions + type_var_mapping: dict[str, BoundTypeVar] = {} arg_tys = [] arg_names = [] for i, arg in enumerate(func_def.args.args): if arg.annotation is None: raise GuppyError("Argument type must be annotated", arg) - ty = type_from_ast(arg.annotation, globals) + ty = type_from_ast(arg.annotation, globals, type_var_mapping) arg_tys.append(ty) arg_names.append(arg.arg) - ret_type = type_from_ast(func_def.returns, globals) - return FunctionType(arg_tys, ret_type, arg_names) + ret_type = type_from_ast(func_def.returns, globals, type_var_mapping) + return FunctionType(arg_tys, ret_type, arg_names, sorted(type_var_mapping.values(), key=lambda v: v.idx)) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 5b12dc58..c7b0219c 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -6,7 +6,7 @@ from guppy.checker.core import Variable, Context from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType +from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType, Subst from guppy.nodes import NestedFunctionDef @@ -16,6 +16,7 @@ class StmtChecker(AstVisitor[BBStatement]): return_ty: GuppyType def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: + assert not return_ty.free_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -28,7 +29,7 @@ def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: def _check_expr( self, node: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: return ExprChecker(self.ctx).check(node, ty, kind) def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: @@ -79,7 +80,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: "Variable declaration is not supported. Assignment is required", node ) ty = type_from_ast(node.annotation, self.ctx.globals) - node.value = self._check_expr(node.value, ty) + node.value, subst = self._check_expr(node.value, ty) + assert not ty.free_vars and len(subst) == 0 # `ty` must be closed! self._check_assign(node.target, ty, node) return node @@ -97,7 +99,8 @@ def visit_Expr(self, node: ast.Expr) -> ast.stmt: def visit_Return(self, node: ast.Return) -> ast.stmt: if node.value is not None: - node.value = self._check_expr(node.value, self.return_ty, "return value") + node.value, subst = self._check_expr(node.value, self.return_ty, "return value") + assert len(subst) == 0 # `self.return_ty` is closed! elif not isinstance(self.return_ty, NoneType): raise GuppyTypeError( f"Expected return value of type `{self.return_ty}`", None diff --git a/guppy/custom.py b/guppy/custom.py index 55a5896e..4936487d 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -13,7 +13,7 @@ UnknownFunctionType, GuppyTypeError, ) -from guppy.gtypes import GuppyType, FunctionType +from guppy.gtypes import GuppyType, FunctionType, Subst from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode from guppy.nodes import GlobalCall @@ -86,9 +86,10 @@ def check_type(self, globals: Globals) -> None: def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: self.call_checker._setup(ctx, node, self) - return with_type(ty, with_loc(node, self.call_checker.check(args, ty))) + new_node, subst = self.call_checker.check(args, ty) + return with_type(ty, with_loc(node, new_node)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: "Context" @@ -166,15 +167,12 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: self.node = node self.func = func - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + @abstractmethod + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: """Checks the return value against a given type. Returns a (possibly) transformed and annotated AST node for the call. """ - args, actual = self.synthesize(args) - raise GuppyTypeError( - f"Expected expression of type `{ty}`, got `{actual}`", self.node - ) @abstractmethod def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: @@ -208,15 +206,15 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: class DefaultCallChecker(CustomCallChecker): """Checks function calls by comparing to a type signature.""" - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker - args = check_call(self.func.ty, args, ty, self.node, self.ctx) - return GlobalCall(func=self.func, args=args) + args, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) + return GlobalCall(func=self.func, args=args, type_args=inst), subst def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: # Use default implementation from the expression checker - args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) - return GlobalCall(func=self.func, args=args), ty + args, ty, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) + return GlobalCall(func=self.func, args=args, type_args=inst), ty class DefaultCallCompiler(CustomCallCompiler): diff --git a/guppy/declared.py b/guppy/declared.py index 91b5283e..29bda97e 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -8,7 +8,7 @@ from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError -from guppy.gtypes import type_to_row, GuppyType +from guppy.gtypes import type_to_row, GuppyType, Subst from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV from guppy.nodes import GlobalCall @@ -32,17 +32,17 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context - ) -> GlobalCall: + ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker - args = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args) + args, subst, inst = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker - args, ty = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args), ty + args, ty, inst = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) diff --git a/guppy/decorator.py b/guppy/decorator.py index 493ddcb2..13171548 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -1,6 +1,6 @@ import functools from dataclasses import dataclass -from typing import Optional, Union, Callable, Any +from typing import Optional, Union, Callable, Any, Iterator, Sequence, ClassVar, TypeVar from guppy.ast_util import AstNode, has_empty_body from guppy.custom import ( @@ -12,7 +12,7 @@ DefaultCallCompiler, ) from guppy.error import GuppyError, pretty_errors -from guppy.gtypes import GuppyType +from guppy.gtypes import GuppyType, TypeTransformer from guppy.hugr import tys, ops from guppy.hugr.hugr import Hugr from guppy.module import GuppyModule, PyFunc, parse_py_func @@ -95,7 +95,8 @@ def dec(c: type) -> type: @dataclass(frozen=True) class NewType(GuppyType): - name = _name + args: Sequence[GuppyType] + name: ClassVar[str] = _name @staticmethod def build( @@ -106,7 +107,11 @@ def build( raise GuppyError( f"Type `{_name}` does not accept type parameters.", node ) - return NewType() + return NewType([]) + + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(self.args) @property def linear(self) -> bool: @@ -115,6 +120,11 @@ def linear(self) -> bool: def to_hugr(self) -> tys.SimpleType: return hugr_ty + def transform(self, transformer: TypeTransformer) -> GuppyType: + return transformer.transform(self) or NewType( + [ty.transform(transformer) for ty in self.args] + ) + def __str__(self) -> str: return _name @@ -127,6 +137,12 @@ def __str__(self) -> str: return dec + @pretty_errors + def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> TypeVar: + """Creates a new type variable in a module.""" + module.register_type_var(name, linear) + return TypeVar(name) + @pretty_errors def custom( self, diff --git a/guppy/error.py b/guppy/error.py index b0ace141..d06b9129 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -132,7 +132,9 @@ def format_source_location( source, line_offset = get_source(loc), get_line_offset(loc) assert source is not None and line_offset is not None source_lines = source.splitlines(keepends=True) - end_col_offset = loc.end_col_offset or len(source_lines[loc.lineno]) + end_col_offset = loc.end_col_offset + if end_col_offset is None or (loc.end_lineno and loc.end_lineno > loc.lineno): + end_col_offset = len(source_lines[loc.lineno - 1]) s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() s += "\n" + loc.col_offset * " " + (end_col_offset - loc.col_offset) * "^" s = textwrap.dedent(s).splitlines() diff --git a/guppy/gtypes.py b/guppy/gtypes.py index ca5c010a..b5345ff9 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -1,7 +1,10 @@ import ast +import functools +import itertools from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING +from typing import Optional, Sequence, TYPE_CHECKING, Mapping, Iterator, ClassVar, \ + Literal import guppy.hugr.tys as tys from guppy.ast_util import AstNode, set_location_from @@ -10,19 +13,63 @@ from guppy.checker.core import Globals +@dataclass(frozen=True) +class TypeVarId: + """Identifier for free type variables.""" + id: int + + _id_generator: ClassVar[Iterator[int]] = itertools.count() + + @classmethod + def new(cls) -> "TypeVarId": + return TypeVarId(next(cls._id_generator)) + + +Subst = dict[TypeVarId, "GuppyType"] +Inst = Sequence["GuppyType"] + + +@dataclass(frozen=True) class GuppyType(ABC): """Base class for all Guppy types. Note that all instances of `GuppyType` subclasses are expected to be immutable. """ - name: str = "" + name: ClassVar[str] + + # Cache for free variables + _free_vars: Mapping["TypeVarId", "FreeTypeVar"] = field(init=False, repr=False) + + def __post_init__(self) -> None: + # Make sure that we don't have higher-rank polymorphic types + for arg in self.type_args: + if isinstance(arg, FunctionType) and arg.quantified: + from guppy.error import InternalGuppyError + + raise InternalGuppyError( + "Tried to construct a higher-rank polymorphic type!" + ) + + # Compute free variables + if isinstance(self, FreeTypeVar): + vs = {self.id: self} + else: + vs: dict[TypeVarId, FreeTypeVar] = {} + for arg in self.type_args: + vs |= arg.free_vars + object.__setattr__(self, "_free_vars", vs) @staticmethod @abstractmethod def build(*args: "GuppyType", node: Optional[AstNode] = None) -> "GuppyType": pass + @property + @abstractmethod + def type_args(self) -> Iterator["GuppyType"]: + pass + @property @abstractmethod def linear(self) -> bool: @@ -32,6 +79,75 @@ def linear(self) -> bool: def to_hugr(self) -> tys.SimpleType: pass + @abstractmethod + def transform(self, transformer: "TypeTransformer") -> "GuppyType": + pass + + @property + def free_vars(self) -> Mapping["TypeVarId", "FreeTypeVar"]: + return self._free_vars + + def substitute(self, s: Subst) -> "GuppyType": + return self.transform(Substituter(s)) + + +@dataclass(frozen=True) +class BoundTypeVar(GuppyType): + """Bound type variable, identified with a de Bruijn index.""" + + idx: int + display_name: str + linear: bool = False + + name: ClassVar[Literal["BoundTypeVar"]] = "BoundTypeVar" + + @staticmethod + def build(*rgs: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + raise NotImplementedError() + + @property + def type_args(self) -> Iterator["GuppyType"]: + return iter(()) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self + + def __str__(self) -> str: + return self.display_name + + def to_hugr(self) -> tys.SimpleType: + raise NotImplementedError() + + +@dataclass(frozen=True) +class FreeTypeVar(GuppyType): + """Free type variable, identified with a globally unique id.""" + + id: TypeVarId + display_name: str + linear: bool = False + + name: ClassVar[Literal["FreeTypeVar"]] = "FreeTypeVar" + + @staticmethod + def build(*rgs: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + raise NotImplementedError() + + @property + def type_args(self) -> Iterator["GuppyType"]: + return iter(()) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self + + def __str__(self) -> str: + return "?" + self.display_name + + def to_hugr(self) -> tys.SimpleType: + from guppy.error import InternalGuppyError + + raise InternalGuppyError("Tried to convert free type variable to Hugr") + @dataclass(frozen=True) class FunctionType(GuppyType): @@ -41,16 +157,18 @@ class FunctionType(GuppyType): default=None, compare=False, # Argument names are not taken into account for type equality ) + quantified: Sequence[BoundTypeVar] = field(default_factory=list) - name: str = "->" + name: ClassVar[Literal["%function"]] = "%function" linear = False def __str__(self) -> str: + prefix = "forall " + ", ".join(str(v) for v in self.quantified) + ". " if self.quantified else "" if len(self.args) == 1: [arg] = self.args - return f"{arg} -> {self.returns}" + return prefix + f"{arg} -> {self.returns}" else: - return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" + return prefix + f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" @staticmethod def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: @@ -58,17 +176,46 @@ def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: # has a special case for function types. raise NotImplementedError() + @property + def type_args(self) -> Iterator[GuppyType]: + return itertools.chain(iter(self.args), iter((self.returns,))) + def to_hugr(self) -> tys.SimpleType: ins = [t.to_hugr() for t in self.args] outs = [t.to_hugr() for t in type_to_row(self.returns)] return tys.FunctionType(input=ins, output=outs, extension_reqs=[]) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or FunctionType( + [ty.transform(transformer) for ty in self.args], + self.returns.transform(transformer), + self.arg_names + ) + + def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": + """Instantiates quantified type variables.""" + assert len(tys) == len(self.quantified) + inst = Instantiator(tys) + return FunctionType( + [ty.transform(inst) for ty in self.args], + self.returns.transform(inst), + self.arg_names, + ) + + def unquantified(self) -> tuple["FunctionType", Sequence[FreeTypeVar]]: + """Replaces all quantified variables with free type variables.""" + inst = [ + FreeTypeVar(TypeVarId.new(), v.display_name, v.linear) + for v in self.quantified + ] + return self.instantiate(inst), inst + @dataclass(frozen=True) class TupleType(GuppyType): element_types: Sequence[GuppyType] - name: str = "tuple" + name: ClassVar[Literal["tuple"]] = "tuple" @staticmethod def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: @@ -86,15 +233,29 @@ def __str__(self) -> str: def linear(self) -> bool: return any(t.linear for t in self.element_types) + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(self.element_types) + + def substitute(self, s: Subst) -> GuppyType: + return TupleType([ty.substitute(s) for ty in self.element_types]) + def to_hugr(self) -> tys.SimpleType: ts = [t.to_hugr() for t in self.element_types] return tys.Tuple(inner=ts) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or TupleType( + [ty.transform(transformer) for ty in self.element_types] + ) + @dataclass(frozen=True) class SumType(GuppyType): element_types: Sequence[GuppyType] + name: ClassVar[str] = "%tuple" + @staticmethod def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: # Sum types cannot be parsed and constructed using `build` since they cannot be @@ -108,6 +269,13 @@ def __str__(self) -> str: def linear(self) -> bool: return any(t.linear for t in self.element_types) + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(self.element_types) + + def substitute(self, s: Subst) -> GuppyType: + return TupleType([ty.substitute(s) for ty in self.element_types]) + def to_hugr(self) -> tys.SimpleType: if all( isinstance(e, TupleType) and len(e.element_types) == 0 @@ -116,10 +284,15 @@ def to_hugr(self) -> tys.SimpleType: return tys.UnitSum(size=len(self.element_types)) return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or TupleType( + [ty.transform(transformer) for ty in self.element_types] + ) + @dataclass(frozen=True) class NoneType(GuppyType): - name: str = "None" + name: ClassVar[Literal["None"]] = "None" linear: bool = False @staticmethod @@ -130,19 +303,29 @@ def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: raise GuppyError("Type `None` is not generic", node) return NoneType() + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(()) + + def substitute(self, s: Subst) -> GuppyType: + return self + def __str__(self) -> str: return "None" def to_hugr(self) -> tys.SimpleType: return tys.Tuple(inner=[]) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self + @dataclass(frozen=True) class BoolType(SumType): """The type of booleans.""" - linear = False - name = "bool" + linear: bool = False + name: ClassVar[Literal["bool"]] = "bool" def __init__(self) -> None: # Hugr bools are encoded as Sum((), ()) @@ -159,34 +342,124 @@ def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: def __str__(self) -> str: return "bool" + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self -def _lookup_type(node: AstNode, globals: "Globals") -> Optional[type[GuppyType]]: - if isinstance(node, ast.Name) and node.id in globals.types: - return globals.types[node.id] - if isinstance(node, ast.Constant) and node.value is None: - return NoneType - if ( - isinstance(node, ast.Constant) - and isinstance(node.value, str) - and node.value in globals.types - ): - return globals.types[node.value] + +class TypeTransformer(ABC): + """Abstract base class for a type visitor that transforms types.""" + + @abstractmethod + def transform(self, ty: GuppyType) -> Optional[GuppyType]: + """This method is called for each visited type. + + Return a transformed type or `None` to continue the recursive visit. + """ + pass + + +class Substituter(TypeTransformer): + """Type transformer that substitutes free type variables.""" + + subst: Subst + + def __init__(self, subst: Subst) -> None: + self.subst = subst + + def transform(self, ty: GuppyType) -> Optional[GuppyType]: + if isinstance(ty, FreeTypeVar): + return self.subst.get(ty.id, None) + return None + + +class Instantiator(TypeTransformer): + """Type transformer that instantiates bound type variables.""" + + tys: Sequence[GuppyType] + + def __init__(self, tys: Sequence[GuppyType]) -> None: + self.tys = tys + + def transform(self, ty: GuppyType) -> Optional[GuppyType]: + if isinstance(ty, BoundTypeVar): + # Instantiate if type for the index is available + if ty.idx < len(self.tys): + return self.tys[ty.idx] + + # Otherwise, lower the de Bruijn index + return BoundTypeVar(ty.idx - len(self.tys), ty.display_name, ty.linear) + return None + + +def unify(s: GuppyType, t: GuppyType, subst: Optional[Subst]) -> Optional[Subst]: + """Computes a most general unifier for two types. + + Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this + not possible. + """ + if subst is None: + return None + if s == t: + return subst + if isinstance(s, FreeTypeVar): + return _unify_var(s, t, subst) + if isinstance(t, FreeTypeVar): + return _unify_var(t, s, subst) + if type(s) == type(t): + sargs, targs = list(s.type_args), list(t.type_args) + if len(sargs) == len(targs): + for sa, ta in zip(sargs, targs): + subst = unify(sa, ta, subst) + return subst return None -def type_from_ast(node: AstNode, globals: "Globals") -> GuppyType: +def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Optional[Subst]: + """Helper function for unification of type variables.""" + if var.id in subst: + return unify(subst[var.id], t, subst) + if isinstance(t, FreeTypeVar) and t.id in subst: + return unify(var, subst[t.id], subst) + if var.id in t.free_vars: + return None + return {var.id: t, **subst} + + +def type_from_ast(node: AstNode, globals: "Globals", type_var_mapping: Optional[dict[str, BoundTypeVar]] = None) -> GuppyType: """Turns an AST expression into a Guppy type.""" - if isinstance(node, ast.Name) and (ty := _lookup_type(node, globals)): - return ty.build(node=node) - if isinstance(node, ast.Constant) and (ty := _lookup_type(node, globals)): - name = ast.Name(id=node.value) - set_location_from(name, node) - return ty.build(node=name) - if isinstance(node, ast.Subscript) and (ty := _lookup_type(node.value, globals)): - args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] - return ty.build(*(type_from_ast(a, globals) for a in args), node=node) + from guppy.error import GuppyError + + if isinstance(node, ast.Name): + x = node.id + if x in globals.types: + return globals.types[x].build(node=node) + if x in globals.type_vars: + if type_var_mapping is None: + raise GuppyError( + "Free type variable. Only function types can be generic", node + ) + var_decl = globals.type_vars[x] + if var_decl.name not in type_var_mapping: + type_var_mapping[var_decl.name] = BoundTypeVar( + len(type_var_mapping), var_decl.name, var_decl.linear + ) + return type_var_mapping[var_decl.name] + raise GuppyError("Unknown type", node) + + if isinstance(node, ast.Constant): + v = node.value + if v is None: + return NoneType() + if isinstance(v, str): + try: + return type_from_ast(ast.parse(v), globals, type_var_mapping) + except Exception: + raise GuppyError("Invalid Guppy type", node) + raise GuppyError(f"Constant `{v}` is not a valid type", node) + if isinstance(node, ast.Tuple): return TupleType([type_from_ast(el, globals) for el in node.elts]) + if ( isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name) @@ -194,13 +467,18 @@ def type_from_ast(node: AstNode, globals: "Globals") -> GuppyType: and isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2 ): + # TODO: Do we want to allow polymorphic Callable types? [func_args, ret] = node.slice.elts if isinstance(func_args, ast.List): return FunctionType( - [type_from_ast(a, globals) for a in func_args.elts], - type_from_ast(ret, globals), + [type_from_ast(a, globals, type_var_mapping) for a in func_args.elts], + type_from_ast(ret, globals, type_var_mapping), ) - from guppy.error import GuppyError + + if isinstance(node, ast.Subscript): + ty = type_from_ast(node.value, globals, type_var_mapping) + args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + return ty.build(*(type_from_ast(a, globals) for a in args), node=node) raise GuppyError("Not a valid Guppy type", node) diff --git a/guppy/module.py b/guppy/module.py index e07a15ac..8047d6ca 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -6,7 +6,7 @@ from typing import Callable, Any, Optional, Union from guppy.ast_util import annotate_location, AstNode -from guppy.checker.core import Globals, qualified_name +from guppy.checker.core import Globals, qualified_name, TypeVarDecl from guppy.checker.func_checker import DefinedFunction, check_global_func_def from guppy.compiler.core import CompiledGlobals from guppy.compiler.func_compiler import compile_global_func_def, CompiledFunctionDef @@ -48,7 +48,7 @@ class GuppyModule: def __init__(self, name: str, import_builtins: bool = True): self.name = name - self._globals = Globals({}, {}) + self._globals = Globals({}, {}, {}) self._compiled_globals = {} self._imported_globals = Globals.default() self._imported_compiled_globals = {} @@ -124,8 +124,16 @@ def register_custom_func( def register_type(self, name: str, ty: type[GuppyType]) -> None: """Registers an existing Guppy type as belonging to this Guppy module.""" + self._check_not_yet_compiled() + self._check_type_name_available(name, None) self._globals.types[name] = ty + def register_type_var(self, name: str, linear: bool) -> None: + """Registers a new type variable""" + self._check_not_yet_compiled() + self._check_type_name_available(name, None) + self._globals.type_vars[name] = TypeVarDecl(name, linear) + def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: assert self._instance_func_buffer is not None buffer = self._instance_func_buffer @@ -207,6 +215,19 @@ def _check_name_available(self, name: str, node: Optional[AstNode]) -> None: node, ) + def _check_type_name_available(self, name: str, node: Optional[AstNode]) -> None: + if name in self._globals.types: + raise GuppyError( + f"Module `{self.name}` already contains a type `{name}`", + node, + ) + + if name in self._globals.type_vars: + raise GuppyError( + f"Module `{self.name}` already contains a type variable `{name}`", + node, + ) + def parse_py_func(f: PyFunc) -> ast.FunctionDef: source_lines, line_offset = inspect.getsourcelines(f) diff --git a/guppy/nodes.py b/guppy/nodes.py index dfd02349..967a4ce8 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -1,9 +1,9 @@ """Custom AST nodes used by Guppy""" import ast -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping, Sequence -from guppy.gtypes import FunctionType +from guppy.gtypes import FunctionType, GuppyType, Inst if TYPE_CHECKING: from guppy.cfg.cfg import CFG @@ -40,12 +40,22 @@ class LocalCall(ast.expr): class GlobalCall(ast.expr): func: "CallableVariable" args: list[ast.expr] - - # Later: Inferred type args + type_args: Inst # Inferred type arguments _fields = ( "func", "args", + "type_args", + ) + + +class TypeApply(ast.expr): + value: ast.expr + tys: Sequence[GuppyType] + + _fields = ( + "value", + "tys", ) From 497cd2328321be6a6bf46708ae6041d11f227fb4 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Fri, 24 Nov 2023 15:02:03 +0000 Subject: [PATCH 25/77] refactor: Move graph generation code (#59) * Add graph generation code * Rename predicate to TupleSum * Add function to add input node with ports --- guppy/compiler/cfg_compiler.py | 166 +++++++++++++++++++++++++++++++- guppy/compiler/expr_compiler.py | 116 ++++++++++++++++++++++ guppy/compiler/func_compiler.py | 69 ++++++++++++- guppy/compiler/stmt_compiler.py | 96 ++++++++++++++++++ guppy/custom.py | 16 ++- guppy/hugr/hugr.py | 7 ++ 6 files changed, 452 insertions(+), 18 deletions(-) create mode 100644 guppy/compiler/expr_compiler.py create mode 100644 guppy/compiler/stmt_compiler.py diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index cf4717eb..086236e9 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -1,17 +1,173 @@ -from guppy.checker.cfg_checker import CheckedBB, CheckedCFG -from guppy.compiler.core import CompiledGlobals -from guppy.hugr.hugr import Hugr, Node, CFNode +import functools +from typing import Sequence + +from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature +from guppy.checker.core import Variable +from guppy.compiler.core import ( + CompiledGlobals, + is_return_var, + DFContainer, + return_var, + PortVariable, +) +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.compiler.stmt_compiler import StmtCompiler +from guppy.gtypes import TupleType, SumType, type_to_row +from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV def compile_cfg( cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals ) -> None: """Compiles a CFG to Hugr.""" - raise NotImplementedError + insert_return_vars(cfg) + + blocks: dict[CheckedBB, CFNode] = {} + for bb in cfg.bbs: + blocks[bb] = compile_bb(bb, graph, parent, globals) + for bb in cfg.bbs: + for succ in bb.successors: + graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) def compile_bb( bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals ) -> CFNode: """Compiles a single basic block to Hugr.""" - raise NotImplementedError + inputs = sort_vars(bb.sig.input_row) + + # The exit BB is completely empty + if len(bb.successors) == 0: + assert len(bb.statements) == 0 + return graph.add_exit([v.ty for v in inputs], parent) + + # Otherwise, we use a regular `Block` node + block = graph.add_block(parent) + + # Add input node and compile the statements + inp = graph.add_input(output_tys=[v.ty for v in inputs], parent=block) + dfg = DFContainer( + block, + { + v.name: PortVariable(v.name, inp.out_port(i), v.defined_at, None) + for (i, v) in enumerate(inputs) + }, + ) + dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) + + # If we branch, we also have to compile the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg) + else: + # Even if we don't branch, we still have to add a `Sum(())` predicates + unit = graph.add_make_tuple([], parent=block).out_port(0) + branch_port = graph.add_tag( + variants=[TupleType([])], tag=0, inp=unit, parent=block + ).out_port(0) + + # Finally, we have to add the block output. + outputs: Sequence[Variable] + if len(bb.successors) == 1: + # The easy case is if we don't branch: We just output all variables that are + # specified by the signature + [outputs] = bb.sig.output_rows + else: + # If we branch and the branches use the same variables, then we can use a + # regular output + first, *rest = bb.sig.output_rows + if all({v.name for v in first} == {v.name for v in r} for r in rest): + outputs = first + else: + # Otherwise, we have to output a TupleSum: We put all non-linear variables + # into the branch TupleSum and all linear variables in the normal output + # (since they are shared between all successors). This is in line with the + # ordering on variables which puts linear variables at the end. The only + # exception are return vars which must be outputted in order. + branch_port = choose_vars_for_tuple_sum( + graph=graph, + unit_sum=branch_port, + output_vars=[ + [v for v in row if not v.ty.linear or is_return_var(v.name)] + for row in bb.sig.output_rows + ], + dfg=dfg, + ) + outputs = [v for v in first if v.ty.linear and not is_return_var(v.name)] + + graph.add_output( + inputs=[branch_port] + [dfg[v.name].port for v in sort_vars(outputs)], + parent=block, + ) + return block + + +def insert_return_vars(cfg: CheckedCFG) -> None: + """Patches a CFG by annotating dummy return variables in the BB signatures. + + The statement compiler turns `return` statements into assignments of dummy variables + `%ret0`, `%ret1`, etc. We update the exit BB signature to make sure they are + correctly outputted. + """ + return_vars = [ + Variable(return_var(i), ty, None, None) + for i, ty in enumerate(type_to_row(cfg.output_ty)) + ] + # Before patching, the exit BB shouldn't take any inputs + assert len(cfg.exit_bb.sig.input_row) == 0 + cfg.exit_bb.sig = Signature(return_vars, cfg.exit_bb.sig.output_rows) + # Also patch the predecessors + for pred in cfg.exit_bb.predecessors: + # The exit BB will be the only successor + assert len(pred.sig.output_rows) == 1 and len(pred.sig.output_rows[0]) == 0 + pred.sig = Signature(pred.sig.input_row, [return_vars]) + + +def choose_vars_for_tuple_sum( + graph: Hugr, unit_sum: OutPortV, output_vars: list[VarRow], dfg: DFContainer +) -> OutPortV: + """Selects an output based on a TupleSum. + + Given `unit_sum: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, + constructs a TupleSum value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. + """ + assert isinstance(unit_sum.ty, SumType) + assert len(unit_sum.ty.element_types) == len(output_vars) + tuples = [ + graph.add_make_tuple( + inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], + parent=dfg.node, + ).out_port(0) + for vs in output_vars + ] + tys = [t.ty for t in tuples] + conditional = graph.add_conditional( + cond_input=unit_sum, inputs=tuples, parent=dfg.node + ) + for i, ty in enumerate(tys): + case = graph.add_case(conditional) + inp = graph.add_input(output_tys=tys, parent=case).out_port(i) + tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) + graph.add_output(inputs=[tag], parent=case) + return conditional.add_out_port(SumType(tys)) + + +def compare_var(x: Variable, y: Variable) -> int: + """Defines a `<` order on variables. + + We use this to determine in which order variables are outputted from basic blocks. + We need to output linear variables at the end, so we do a lexicographic ordering of + linearity and name. The only exception are return vars which must be outputted in + order. + """ + if is_return_var(x.name) and is_return_var(y.name): + return -1 if x.name < y.name else 1 + return -1 if (x.ty.linear, x.name) < (y.ty.linear, y.name) else 1 + + +def sort_vars(row: VarRow) -> list[Variable]: + """Sorts a row of variables. + + This determines the order in which they are outputted from a BB. + """ + return sorted(row, key=functools.cmp_to_key(compare_var)) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py new file mode 100644 index 00000000..5fdc378b --- /dev/null +++ b/guppy/compiler/expr_compiler.py @@ -0,0 +1,116 @@ +import ast +from typing import Any, Optional + +from guppy.ast_util import AstVisitor, get_type +from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction +from guppy.error import InternalGuppyError +from guppy.gtypes import FunctionType, type_to_row, BoolType +from guppy.hugr import ops, val +from guppy.hugr.hugr import OutPortV +from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall + + +class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): + """A compiler from Guppy expressions to Hugr.""" + + dfg: DFContainer + + def compile(self, expr: ast.expr, dfg: DFContainer) -> OutPortV: + """Compiles an expression and returns a single port holding the output value.""" + self.dfg = dfg + with self.graph.parent(dfg.node): + res = self.visit(expr) + return res + + def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: + """Compiles a row expression and returns a list of ports, one for each value in + the row. + + On Python-level, we treat tuples like rows on top-level. However, nested tuples + are treated like regular Guppy tuples. + """ + return [self.compile(e, dfg) for e in expr_to_row(expr)] + + def visit_Constant(self, node: ast.Constant) -> OutPortV: + if value := python_value_to_hugr(node.value): + const = self.graph.add_constant(value, get_type(node)).out_port(0) + return self.graph.add_load_constant(const).out_port(0) + raise InternalGuppyError("Unsupported constant expression in compiler") + + def visit_LocalName(self, node: LocalName) -> OutPortV: + return self.dfg[node.id].port + + def visit_GlobalName(self, node: GlobalName) -> OutPortV: + return self.globals[node.id].load(self.dfg, self.graph, self.globals, node) + + def visit_Name(self, node: ast.Name) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Tuple(self, node: ast.Tuple) -> OutPortV: + return self.graph.add_make_tuple( + inputs=[self.visit(e) for e in node.elts] + ).out_port(0) + + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: + """Groups function return values into a tuple""" + if len(returns) != 1: + return self.graph.add_make_tuple(inputs=returns).out_port(0) + return returns[0] + + def visit_LocalCall(self, node: LocalCall) -> OutPortV: + func = self.visit(node.func) + assert isinstance(func.ty, FunctionType) + + args = [self.visit(arg) for arg in node.args] + call = self.graph.add_indirect_call(func, args) + rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.returns)))] + return self._pack_returns(rets) + + def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: + func = self.globals[node.func.name] + assert isinstance(func, CompiledFunction) + + args = [self.visit(arg) for arg in node.args] + rets = func.compile_call(args, self.dfg, self.graph, self.globals, node) + return self._pack_returns(rets) + + def visit_Call(self, node: ast.Call) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: + # The only case that is not desugared by the type checker is the `not` operation + # since it is not implemented via a dunder method + if isinstance(node.op, ast.Not): + arg = self.visit(node.operand) + return self.graph.add_node( + ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] + ).add_out_port(BoolType()) + + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_BinOp(self, node: ast.BinOp) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Compare(self, node: ast.Compare) -> OutPortV: + raise InternalGuppyError("Node should have been removed during type checking.") + + +def expr_to_row(expr: ast.expr) -> list[ast.expr]: + """Turns an expression into a row expressions by unpacking top-level tuples.""" + return expr.elts if isinstance(expr, ast.Tuple) else [expr] + + +def python_value_to_hugr(v: Any) -> Optional[val.Value]: + """Turns a Python value into a Hugr value. + + Returns None if the Python value cannot be represented in Guppy. + """ + from guppy.prelude._internal import int_value, bool_value, float_value + + if isinstance(v, bool): + return bool_value(v) + elif isinstance(v, int): + return int_value(v) + elif isinstance(v, float): + return float_value(v) + return None diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index bdc102df..4631f0db 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -2,12 +2,14 @@ from guppy.ast_util import AstNode from guppy.checker.func_checker import CheckedFunction, DefinedFunction +from guppy.compiler.cfg_compiler import compile_cfg from guppy.compiler.core import ( CompiledFunction, CompiledGlobals, DFContainer, PortVariable, ) +from guppy.gtypes import type_to_row, FunctionType from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef @@ -19,7 +21,7 @@ class CompiledFunctionDef(DefinedFunction, CompiledFunction): def load( self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode ) -> OutPortV: - raise NotImplementedError + return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) def compile_call( self, @@ -29,7 +31,8 @@ def compile_call( globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - raise NotImplementedError + call = graph.add_call(self.node.out_port(0), args, dfg.node) + return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] def compile_global_func_def( @@ -39,7 +42,17 @@ def compile_global_func_def( globals: CompiledGlobals, ) -> CompiledFunctionDef: """Compiles a top-level function definition to Hugr.""" - raise NotImplementedError + _, ports = graph.add_input_with_ports(list(func.ty.args), def_node) + cfg_node = graph.add_cfg(def_node, ports) + compile_cfg(func.cfg, graph, cfg_node, globals) + + # Add output node for the cfg + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) + + return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) def compile_local_func_def( @@ -49,4 +62,52 @@ def compile_local_func_def( globals: CompiledGlobals, ) -> PortVariable: """Compiles a local (nested) function definition to Hugr.""" - raise NotImplementedError + assert func.ty.arg_names is not None + + # Pick an order for the captured variables + captured = list(func.captured.values()) + + # Prepend captured variables to the function arguments + closure_ty = FunctionType( + [v.ty for v in captured] + list(func.ty.args), + func.ty.returns, + [v.name for v in captured] + list(func.ty.arg_names), + ) + + def_node = graph.add_def(closure_ty, dfg.node, func.name) + def_input, input_ports = graph.add_input_with_ports(list(closure_ty.args), def_node) + + # If we have captured variables and the body contains a recursive occurrence of + # the function itself, then we provide the partially applied function as a local + # variable + if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: + loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0) + partial = graph.add_partial( + loaded, [def_input.out_port(i) for i in range(len(captured))], def_node + ) + input_ports.append(partial.out_port(0)) + func.cfg.input_tys.append(func.ty) + else: + # Otherwise, we treat the function like a normal global variable + globals = globals | { + func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node) + } + + # Compile the CFG + cfg_node = graph.add_cfg(def_node, inputs=input_ports) + compile_cfg(func.cfg, graph, cfg_node, globals) + + # Add output node for the cfg + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) + + # Finally, load the function into the local data-flow graph + loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) + if len(captured) > 0: + loaded = graph.add_partial( + loaded, [dfg[v.name].port for v in captured], dfg.node + ).out_port(0) + + return PortVariable(func.name, loaded, func) diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py new file mode 100644 index 00000000..db656b6f --- /dev/null +++ b/guppy/compiler/stmt_compiler.py @@ -0,0 +1,96 @@ +import ast +from typing import Sequence + +from guppy.ast_util import AstVisitor +from guppy.checker.cfg_checker import CheckedBB +from guppy.compiler.core import ( + CompilerBase, + DFContainer, + CompiledGlobals, + PortVariable, + return_var, +) +from guppy.compiler.expr_compiler import ExprCompiler +from guppy.error import InternalGuppyError +from guppy.gtypes import TupleType +from guppy.hugr.hugr import Hugr, OutPortV +from guppy.nodes import CheckedNestedFunctionDef + + +class StmtCompiler(CompilerBase, AstVisitor[None]): + """A compiler for Guppy statements to Hugr""" + + expr_compiler: ExprCompiler + + bb: CheckedBB + dfg: DFContainer + + def __init__(self, graph: Hugr, globals: CompiledGlobals): + super().__init__(graph, globals) + self.expr_compiler = ExprCompiler(graph, globals) + + def compile_stmts( + self, + stmts: Sequence[ast.stmt], + bb: CheckedBB, + dfg: DFContainer, + ) -> DFContainer: + """Compiles a list of basic statements into a dataflow node. + + Note that the `dfg` is mutated in-place. After compilation, the DFG will also + contain all variables that are assigned in the given list of statements. + """ + self.bb = bb + self.dfg = dfg + for s in stmts: + self.visit(s) + return self.dfg + + def _unpack_assign(self, lhs: ast.expr, port: OutPortV, node: ast.stmt) -> None: + """Updates the local DFG with assignments.""" + if isinstance(lhs, ast.Name): + x = lhs.id + self.dfg[x] = PortVariable(x, port, node) + elif isinstance(lhs, ast.Tuple): + unpack = self.graph.add_unpack_tuple(port, self.dfg.node) + for i, pat in enumerate(lhs.elts): + self._unpack_assign(pat, unpack.out_port(i), node) + else: + raise InternalGuppyError("Invalid assign pattern in compiler") + + def visit_Assign(self, node: ast.Assign) -> None: + [target] = node.targets + port = self.expr_compiler.compile(node.value, self.dfg) + self._unpack_assign(target, port, node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + assert node.value is not None + port = self.expr_compiler.compile(node.value, self.dfg) + self._unpack_assign(node.target, port, node) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + raise InternalGuppyError("Node should have been removed during type checking.") + + def visit_Expr(self, node: ast.Expr) -> None: + self.expr_compiler.compile_row(node.value, self.dfg) + + def visit_Return(self, node: ast.Return) -> None: + # We turn returns into assignments of dummy variables, i.e. the statement + # `return e0, e1, e2` is turned into `%ret0 = e0; %ret1 = e1; %ret2 = e2`. + if node.value is not None: + port = self.expr_compiler.compile(node.value, self.dfg) + if isinstance(port.ty, TupleType): + unpack = self.graph.add_unpack_tuple(port, self.dfg.node) + row = [unpack.out_port(i) for i in range(len(port.ty.element_types))] + else: + row = [port] + for i, port in enumerate(row): + name = return_var(i) + self.dfg[name] = PortVariable(name, port, node) + + def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: + from guppy.compiler.func_compiler import compile_local_func_def + + self.dfg[node.name] = compile_local_func_def( + node, self.dfg, self.graph, self.globals + ) diff --git a/guppy/custom.py b/guppy/custom.py index 4027232f..6ed5ae0e 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional -from guppy.ast_util import AstNode, with_type, with_loc +from guppy.ast_util import AstNode, with_type, with_loc, get_type from guppy.checker.core import Context, Globals from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals @@ -12,7 +12,7 @@ UnknownFunctionType, GuppyTypeError, ) -from guppy.gtypes import GuppyType, FunctionType +from guppy.gtypes import GuppyType, FunctionType, type_to_row from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode @@ -135,13 +135,9 @@ def load( # to the function, and returns the results. if module not in self._defined: def_node = graph.add_def(self.ty, module, self.name) - inp = graph.add_input(list(self.ty.args), parent=def_node) + _, inp_ports = graph.add_input_with_ports(list(self.ty.args), def_node) returns = self.compile_call( - [inp.out_port(i) for i in range(len(self.ty.args))], - DFContainer(def_node, {}), - graph, - globals, - node, + inp_ports, DFContainer(def_node, {}), graph, globals, node ) graph.add_output(returns, parent=def_node) self._defined[module] = def_node @@ -227,4 +223,6 @@ def __init__(self, op: ops.OpType) -> None: self.op = op def compile(self, args: list[OutPortV]) -> list[OutPortV]: - raise NotImplementedError + node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) + return_ty = get_type(self.node) + return [node.add_out_port(ty) for ty in type_to_row(return_ty)] diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 511fe702..35880c29 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -375,6 +375,13 @@ def add_input( parent.input_child = node return node + def add_input_with_ports( + self, output_tys: TypeList, parent: Optional[Node] = None + ) -> tuple[VNode, list[OutPortV]]: + """Adds an `Input` node to the graph.""" + node = self.add_input(output_tys, parent) + return node, [node.add_out_port(ty) for ty in output_tys] + def add_output( self, inputs: Optional[list[OutPortV]] = None, From a3ba236e3687b41eb69698d026f11913f5c1d4d3 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 24 Nov 2023 15:13:39 +0000 Subject: [PATCH 26/77] Fix add_input_with_ports --- guppy/hugr/hugr.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 35880c29..59a00e97 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Optional, Iterator, Tuple, Any +from typing import Optional, Iterator, Tuple, Any, Sequence from dataclasses import field, dataclass import guppy.hugr.ops as ops @@ -376,11 +376,12 @@ def add_input( return node def add_input_with_ports( - self, output_tys: TypeList, parent: Optional[Node] = None + self, output_tys: Sequence[GuppyType], parent: Optional[Node] = None ) -> tuple[VNode, list[OutPortV]]: """Adds an `Input` node to the graph.""" - node = self.add_input(output_tys, parent) - return node, [node.add_out_port(ty) for ty in output_tys] + node = self.add_input(list(output_tys), parent) + ports = [node.add_out_port(ty) for ty in output_tys] + return node, ports def add_output( self, From 0b4ee03269c2efdf854c0d77dbb13911356eab2b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 24 Nov 2023 15:15:25 +0000 Subject: [PATCH 27/77] Make CustomCallChecker.check abstract again --- guppy/custom.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/guppy/custom.py b/guppy/custom.py index f17eebbb..f0c85c79 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -162,15 +162,12 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: self.node = node self.func = func + @abstractmethod def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: """Checks the return value against a given type. Returns a (possibly) transformed and annotated AST node for the call. """ - args, actual = self.synthesize(args) - raise GuppyTypeError( - f"Expected expression of type `{ty}`, got `{actual}`", self.node - ) @abstractmethod def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: From c26aa2e4e7b0eebb8c25b5c02067242e4fa0d2a1 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 24 Nov 2023 15:57:39 +0000 Subject: [PATCH 28/77] Fix type parsing --- guppy/gtypes.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/guppy/gtypes.py b/guppy/gtypes.py index b5345ff9..7ce8745c 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -116,7 +116,8 @@ def __str__(self) -> str: return self.display_name def to_hugr(self) -> tys.SimpleType: - raise NotImplementedError() + # TODO + return NoneType().to_hugr() @dataclass(frozen=True) @@ -475,10 +476,11 @@ def type_from_ast(node: AstNode, globals: "Globals", type_var_mapping: Optional[ type_from_ast(ret, globals, type_var_mapping), ) - if isinstance(node, ast.Subscript): - ty = type_from_ast(node.value, globals, type_var_mapping) - args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] - return ty.build(*(type_from_ast(a, globals) for a in args), node=node) + if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): + x = node.value.id + if x in globals.types: + args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + return globals.types[x].build(*(type_from_ast(a, globals) for a in args), node=node) raise GuppyError("Not a valid Guppy type", node) From 35739a8b767b0565dd6498c89479ab78c1538b2d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 24 Nov 2023 15:57:54 +0000 Subject: [PATCH 29/77] Compile TypeApply as dummy node --- guppy/compiler/expr_compiler.py | 6 +++++- guppy/hugr/hugr.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 5fdc378b..3dac0690 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -7,7 +7,7 @@ from guppy.gtypes import FunctionType, type_to_row, BoolType from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV -from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall +from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall, TypeApply class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -77,6 +77,10 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_TypeApply(self, node: TypeApply) -> OutPortV: + func = self.visit(node.value) + return self.graph.add_type_apply(func, node.tys, self.dfg.node).out_port(0) + def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: # The only case that is not desugared by the type checker is the `not` operation # since it is not implemented via a dunder method diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 59a00e97..0eca0f20 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -14,7 +14,7 @@ FunctionType, SumType, type_to_row, - row_to_type, + row_to_type, Inst, ) from guppy.hugr import val @@ -531,6 +531,19 @@ def add_partial( ops.DummyOp(name="partial"), None, [new_ty], parent, args + [def_port] ) + def add_type_apply( + self, func_port: OutPortV, tys: Inst, parent: Optional[Node] = None + ) -> VNode: + """Adds a `TypeApply` node to the graph.""" + assert isinstance(func_port.ty, FunctionType) + assert len(func_port.ty.quantified) == len(tys) + return self.add_node( + ops.DummyOp(name="TypeApply"), + inputs=[func_port], + output_types=[func_port.ty.instantiate(tys)], + parent=parent, + ) + def add_def( self, fun_ty: FunctionType, parent: Optional[Node], name: str ) -> DFContainingVNode: From 8cf7e1ea2029c6b1218bd8bf87d12f7dbdfc646f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 09:31:30 +0000 Subject: [PATCH 30/77] Try synthesis before checking calls --- guppy/checker/expr_checker.py | 36 +++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 350a0ba6..0dd3ea20 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -428,8 +428,40 @@ def check_call( assert not func_ty.free_vars check_num_args(len(func_ty.args), len(args), node) - # Replace quantified variables with free unification variables and try to do some - # inference based on the expected return type + # When checking, we can use the information from the expected return type to infer + # some type arguments. However, this pushes errors inwards. For example, given a + # function `foo: forall T. T -> T`, the following type mismatch would be reported: + # + # x: int = foo(None) + # ^^^^ Expected argument of type `int`, got `None` + # + # The following error location would be easier to understand for users: + # + # x: int = foo(None) + # ^^^^^^^^^ Expected expression of type `int`, got `None` + # + # Therefore, we should only resort to using the type information for inference if + # the regular synthesis method doesn't succeed. + + # TODO: This approach can result in exponential runtime in the worst case. However + # the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common in + # practice + + # First, try to synthesize + res: Optional[tuple[GuppyType, Inst]] = None + try: + args, synth, inst = synthesize_call(func_ty, args, node, ctx) + res = synth, inst + except GuppyTypeError: + pass + if res is not None: + synth, inst = res + subst = unify(ty, synth, {}) + if subst is None: + raise GuppyTypeError(f"Expected {kind} of type `{ty}`, got `{synth}`", node) + return args, subst, inst + + # Only if synthesis fails, we try to infer more from the return type unquantified, free_vars = func_ty.unquantified() subst = unify(ty, unquantified.returns, {}) if subst is None: From a9238ec17d96f8153f78d6e1ce39e5101de20cd9 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 09:32:36 +0000 Subject: [PATCH 31/77] Fix annotated assign statement --- guppy/cfg/bb.py | 5 +++++ guppy/cfg/builder.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index e3543204..eb6e0ff9 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -103,6 +103,11 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None: for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + self.stats.update_used(node.value) + for name in name_nodes_in_ast(node.target): + self.stats.assigned[name.id] = node + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # In order to compute the used external variables in a nested function # definition, we have to run live variable analysis first diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 7285a613..71f1da1c 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -96,6 +96,11 @@ def visit_AugAssign( ) -> Optional[BB]: return self._build_node_value(node, bb) + def visit_AnnAssign( + self, node: ast.AnnAssign, bb: BB, jumps: Jumps + ) -> Optional[BB]: + return self._build_node_value(node, bb) + def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]: # This is an expression statement where the value is discarded node.value, bb = ExprBuilder.build(node.value, self.cfg, bb) From 3af488fd4d587b5239b882d470dbc111b55d7e43 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 10:25:51 +0000 Subject: [PATCH 32/77] Factor out polymorphic checking code --- guppy/checker/expr_checker.py | 85 +++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 0dd3ea20..75f27863 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -134,41 +134,12 @@ def generic_visit( # type: ignore ) -> tuple[ast.expr, Subst]: # Try to synthesize and then check if we can unify it with the given type node, synth = self._synthesize(node, allow_free_vars=False) + subst, inst = check_type_against(synth, ty, node, self._kind) - # Special case if we synthesized a polymorphic function type. In that case, we - # have to find an instantiation to avoid higher-rank types. - subst: Optional[Subst] - if isinstance(synth, FunctionType) and synth.quantified: - unquantified, free_vars = synth.unquantified() - subst = unify(ty, unquantified, {}) - if subst is None: - self._fail(ty, synth, node) - # Check that we have found a valid instantiation for all quantified vars - for i, v in enumerate(free_vars): - if v.id not in subst: - raise GuppyTypeError( - f"Expected {self._kind} of type `{ty}`, got `{synth}`. " - "Couldn't infer an instantiation for type variable " - f"`{synth.quantified[i]}`", - node - ) - if subst[v.id].free_vars: - raise GuppyTypeError( - f"Expected {self._kind} of type `{ty}`, got `{synth}`. Can't " - f"instantiate type variable `{synth.quantified[i]}` with type " - f"`{subst[v.id]}` containing free variables", - node - ) - inst = [subst[v.id] for v in free_vars] + # Apply instantiation of quantified type variables + if inst: node = with_loc(node, TypeApply(value=node, tys=inst)) - subst = {v: t for v, t in subst.items() if v in ty.free_vars} - return node, subst - # Otherwise, we know that `synth` has no free type vars, so unification is - # trivial - subst = unify(ty, synth, {}) - if subst is None: - self._fail(ty, synth, node) return node, subst @@ -357,6 +328,52 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: ) +def check_type_against(act: GuppyType, exp: GuppyType, node: AstNode, kind: str = "expression") -> tuple[Subst, Inst]: + """Checks a type against another type. + + Returns a substitution for the free variables the expected type and an instantiation + for the quantified variables in the actual type. + """ + # Expected type may not be quantified + assert not isinstance(exp, FunctionType) or not exp.quantified + assert not act.free_vars + + # However, the actual type may be. In that case, we have to find an instantiation to + # avoid higher-rank types. + subst: Optional[Subst] + if isinstance(act, FunctionType) and act.quantified: + unquantified, free_vars = act.unquantified() + subst = unify(exp, unquantified, {}) + if subst is None: + raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) + # Check that we have found a valid instantiation for all quantified vars + for i, v in enumerate(free_vars): + if v.id not in subst: + raise GuppyTypeError( + f"Expected {kind} of type `{exp}`, got `{act}`. Couldn't infer an " + f"instantiation for type variable `{act.quantified[i]}` (higher-" + "rank polymorphic types are not supported)", + node + ) + if subst[v.id].free_vars: + raise GuppyTypeError( + f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " + f"type variable `{act.quantified[i]}` with type `{subst[v.id]}` " + "containing free variables", + node + ) + inst = [subst[v.id] for v in free_vars] + subst = {v: t for v, t in subst.items() if v in exp.free_vars} + return subst, inst + + # Otherwise, we know that `act` has no free type vars, so unification is trivial + assert not act.free_vars + subst = unify(exp, act, {}) + if subst is None: + raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) + return subst, [] + + def type_check_args(args: list[ast.expr], func_ty: FunctionType, subst: Subst, ctx: Context, node: AstNode) -> tuple[list[ast.expr], Subst]: """Checks the arguments of a function call and infers free type variables. @@ -379,7 +396,7 @@ def type_check_args(args: list[ast.expr], func_ty: FunctionType, subst: Subst, c for arg in func_ty.args ) - # We just have to check that we found instantiations for all vars in the return type + # We also have to check that we found instantiations for all vars in the return type if not set.issubset(set(func_ty.returns.free_vars.keys()), subst.keys()): raise GuppyTypeError( f"Cannot infer type variable in expression of type " @@ -435,7 +452,7 @@ def check_call( # x: int = foo(None) # ^^^^ Expected argument of type `int`, got `None` # - # The following error location would be easier to understand for users: + # The following error location would be more intuitive for users: # # x: int = foo(None) # ^^^^^^^^^ Expected expression of type `int`, got `None` From ed23fa3b2d41eb94babedc64d974a3038f508f95 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 10:33:17 +0000 Subject: [PATCH 33/77] Fix polymorphic type parsing --- guppy/gtypes.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 7ce8745c..e7a8cff5 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -3,8 +3,15 @@ import itertools from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING, Mapping, Iterator, ClassVar, \ - Literal +from typing import ( + Optional, + Sequence, + TYPE_CHECKING, + Mapping, + Iterator, + ClassVar, + Literal, +) import guppy.hugr.tys as tys from guppy.ast_util import AstNode, set_location_from @@ -16,6 +23,7 @@ @dataclass(frozen=True) class TypeVarId: """Identifier for free type variables.""" + id: int _id_generator: ClassVar[Iterator[int]] = itertools.count() @@ -164,12 +172,18 @@ class FunctionType(GuppyType): linear = False def __str__(self) -> str: - prefix = "forall " + ", ".join(str(v) for v in self.quantified) + ". " if self.quantified else "" + prefix = ( + "forall " + ", ".join(str(v) for v in self.quantified) + ". " + if self.quantified + else "" + ) if len(self.args) == 1: [arg] = self.args return prefix + f"{arg} -> {self.returns}" else: - return prefix + f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" + return ( + prefix + f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" + ) @staticmethod def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: @@ -190,7 +204,7 @@ def transform(self, transformer: "TypeTransformer") -> GuppyType: return transformer.transform(self) or FunctionType( [ty.transform(transformer) for ty in self.args], self.returns.transform(transformer), - self.arg_names + self.arg_names, ) def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": @@ -426,7 +440,11 @@ def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Optional[Subst]: return {var.id: t, **subst} -def type_from_ast(node: AstNode, globals: "Globals", type_var_mapping: Optional[dict[str, BoundTypeVar]] = None) -> GuppyType: +def type_from_ast( + node: AstNode, + globals: "Globals", + type_var_mapping: Optional[dict[str, BoundTypeVar]] = None, +) -> GuppyType: """Turns an AST expression into a Guppy type.""" from guppy.error import GuppyError @@ -459,7 +477,9 @@ def type_from_ast(node: AstNode, globals: "Globals", type_var_mapping: Optional[ raise GuppyError(f"Constant `{v}` is not a valid type", node) if isinstance(node, ast.Tuple): - return TupleType([type_from_ast(el, globals) for el in node.elts]) + return TupleType( + [type_from_ast(el, globals, type_var_mapping) for el in node.elts] + ) if ( isinstance(node, ast.Subscript) @@ -479,8 +499,12 @@ def type_from_ast(node: AstNode, globals: "Globals", type_var_mapping: Optional[ if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): x = node.value.id if x in globals.types: - args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] - return globals.types[x].build(*(type_from_ast(a, globals) for a in args), node=node) + args = ( + node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + ) + return globals.types[x].build( + *(type_from_ast(a, globals, type_var_mapping) for a in args), node=node + ) raise GuppyError("Not a valid Guppy type", node) From d252bca27a07c8287776ed3c044be8cea36be0a9 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 10:33:29 +0000 Subject: [PATCH 34/77] Run formatting --- guppy/checker/core.py | 5 ++-- guppy/checker/expr_checker.py | 48 ++++++++++++++++++++++++----------- guppy/checker/func_checker.py | 17 ++++++++++--- guppy/checker/stmt_checker.py | 4 ++- guppy/hugr/hugr.py | 3 ++- 5 files changed, 55 insertions(+), 22 deletions(-) diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 6c9c8957..1659b284 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -10,7 +10,8 @@ TupleType, SumType, NoneType, - BoolType, Subst, + BoolType, + Subst, ) @@ -90,7 +91,7 @@ def __or__(self, other: "Globals") -> "Globals": return Globals( self.values | other.values, self.types | other.types, - self.type_vars | other.type_vars + self.type_vars | other.type_vars, ) def __ior__(self, other: "Globals") -> "Globals": diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 75f27863..fc4d4e9f 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -4,8 +4,16 @@ from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType, Subst, \ - FreeTypeVar, unify, Inst +from guppy.gtypes import ( + GuppyType, + TupleType, + FunctionType, + BoolType, + Subst, + FreeTypeVar, + unify, + Inst, +) from guppy.nodes import LocalName, GlobalName, LocalCall, TypeApply # Mapping from unary AST op to dunder method and display name @@ -89,7 +97,9 @@ def check( self._kind = old_kind return with_type(ty.substitute(subst), expr), subst - def _synthesize(self, node: ast.expr, allow_free_vars: bool) -> tuple[ast.expr, GuppyType]: + def _synthesize( + self, node: ast.expr, allow_free_vars: bool + ) -> tuple[ast.expr, GuppyType]: """Invokes the type synthesiser""" return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) @@ -149,7 +159,9 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): def __init__(self, ctx: Context) -> None: self.ctx = ctx - def synthesize(self, node: ast.expr, allow_free_vars: bool = False) -> tuple[ast.expr, GuppyType]: + def synthesize( + self, node: ast.expr, allow_free_vars: bool = False + ) -> tuple[ast.expr, GuppyType]: """Tries to synthesise a type for the given expression. Also returns a new desugared expression with type annotations. @@ -159,8 +171,7 @@ def synthesize(self, node: ast.expr, allow_free_vars: bool = False) -> tuple[ast node, ty = self.visit(node) if ty.free_vars and not allow_free_vars: raise GuppyTypeError( - f"Cannot infer type variable in expression of type `{ty}`", - node + f"Cannot infer type variable in expression of type `{ty}`", node ) return with_type(ty, node), ty @@ -328,7 +339,9 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: ) -def check_type_against(act: GuppyType, exp: GuppyType, node: AstNode, kind: str = "expression") -> tuple[Subst, Inst]: +def check_type_against( + act: GuppyType, exp: GuppyType, node: AstNode, kind: str = "expression" +) -> tuple[Subst, Inst]: """Checks a type against another type. Returns a substitution for the free variables the expected type and an instantiation @@ -353,14 +366,14 @@ def check_type_against(act: GuppyType, exp: GuppyType, node: AstNode, kind: str f"Expected {kind} of type `{exp}`, got `{act}`. Couldn't infer an " f"instantiation for type variable `{act.quantified[i]}` (higher-" "rank polymorphic types are not supported)", - node + node, ) if subst[v.id].free_vars: raise GuppyTypeError( f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " f"type variable `{act.quantified[i]}` with type `{subst[v.id]}` " "containing free variables", - node + node, ) inst = [subst[v.id] for v in free_vars] subst = {v: t for v, t in subst.items() if v in exp.free_vars} @@ -374,7 +387,13 @@ def check_type_against(act: GuppyType, exp: GuppyType, node: AstNode, kind: str return subst, [] -def type_check_args(args: list[ast.expr], func_ty: FunctionType, subst: Subst, ctx: Context, node: AstNode) -> tuple[list[ast.expr], Subst]: +def type_check_args( + args: list[ast.expr], + func_ty: FunctionType, + subst: Subst, + ctx: Context, + node: AstNode, +) -> tuple[list[ast.expr], Subst]: """Checks the arguments of a function call and infers free type variables. We expect that quantified variables have been replaced with free unification @@ -392,8 +411,7 @@ def type_check_args(args: list[ast.expr], func_ty: FunctionType, subst: Subst, c # If the argument check succeeded, this means that we must have found instantiations # for all unification variables occurring in the argument types assert all( - set.issubset(set(arg.free_vars.keys()), subst.keys()) - for arg in func_ty.args + set.issubset(set(arg.free_vars.keys()), subst.keys()) for arg in func_ty.args ) # We also have to check that we found instantiations for all vars in the return type @@ -401,7 +419,7 @@ def type_check_args(args: list[ast.expr], func_ty: FunctionType, subst: Subst, c raise GuppyTypeError( f"Cannot infer type variable in expression of type " f"`{func_ty.returns.substitute(subst)}`", - node + node, ) return new_args, subst @@ -435,7 +453,7 @@ def check_call( ty: GuppyType, node: AstNode, ctx: Context, - kind: str = "expression" + kind: str = "expression", ) -> tuple[list[ast.expr], Subst, Inst]: """Checks the return type of a function call against a given type. @@ -495,7 +513,7 @@ def check_call( raise GuppyTypeError( f"Expected expression of type `{ty}`, got " f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", - node + node, ) # Success implies that the substitution is closed diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 5fd47c4d..eeb55c1b 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -8,8 +8,14 @@ from guppy.checker.cfg_checker import check_cfg, CheckedCFG from guppy.checker.expr_checker import synthesize_call, check_call from guppy.error import GuppyError -from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType, Subst, \ - BoundTypeVar +from guppy.gtypes import ( + FunctionType, + type_from_ast, + NoneType, + GuppyType, + Subst, + BoundTypeVar, +) from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef @@ -184,4 +190,9 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType arg_names.append(arg.arg) ret_type = type_from_ast(func_def.returns, globals, type_var_mapping) - return FunctionType(arg_tys, ret_type, arg_names, sorted(type_var_mapping.values(), key=lambda v: v.idx)) + return FunctionType( + arg_tys, + ret_type, + arg_names, + sorted(type_var_mapping.values(), key=lambda v: v.idx), + ) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index c7b0219c..4e9535f7 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -99,7 +99,9 @@ def visit_Expr(self, node: ast.Expr) -> ast.stmt: def visit_Return(self, node: ast.Return) -> ast.stmt: if node.value is not None: - node.value, subst = self._check_expr(node.value, self.return_ty, "return value") + node.value, subst = self._check_expr( + node.value, self.return_ty, "return value" + ) assert len(subst) == 0 # `self.return_ty` is closed! elif not isinstance(self.return_ty, NoneType): raise GuppyTypeError( diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 0eca0f20..b76c0aaa 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -14,7 +14,8 @@ FunctionType, SumType, type_to_row, - row_to_type, Inst, + row_to_type, + Inst, ) from guppy.hugr import val From 1dd9d2b9414ef811d348da889d0ef9463e64ac5e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 11:21:09 +0000 Subject: [PATCH 35/77] Improve docstrings and specialise inference error --- guppy/checker/expr_checker.py | 64 ++++++++++++++++++----------------- guppy/error.py | 6 ++++ 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index fc4d4e9f..9dc52d86 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -3,7 +3,8 @@ from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError, \ + GuppyTypeInferenceError from guppy.gtypes import ( GuppyType, TupleType, @@ -325,34 +326,20 @@ def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: ) -def check_num_args(exp: int, act: int, node: AstNode) -> None: - """Checks that the correct number of arguments have been passed to a function.""" - if act < exp: - raise GuppyTypeError( - f"Not enough arguments passed (expected {exp}, got {act})", node - ) - if exp < act: - if isinstance(node, ast.Call): - raise GuppyTypeError("Unexpected argument", node.args[exp]) - raise GuppyTypeError( - f"Too many arguments passed (expected {exp}, got {act})", node - ) - - def check_type_against( act: GuppyType, exp: GuppyType, node: AstNode, kind: str = "expression" ) -> tuple[Subst, Inst]: """Checks a type against another type. Returns a substitution for the free variables the expected type and an instantiation - for the quantified variables in the actual type. + for the quantified variables in the actual type. Note that the expected type may not + be quantified and the actual type may not contain free unification variables. """ - # Expected type may not be quantified assert not isinstance(exp, FunctionType) or not exp.quantified assert not act.free_vars - # However, the actual type may be. In that case, we have to find an instantiation to - # avoid higher-rank types. + # The actual type may be quantified. In that case, we have to find an instantiation + # to avoid higher-rank types. subst: Optional[Subst] if isinstance(act, FunctionType) and act.quantified: unquantified, free_vars = act.unquantified() @@ -362,7 +349,7 @@ def check_type_against( # Check that we have found a valid instantiation for all quantified vars for i, v in enumerate(free_vars): if v.id not in subst: - raise GuppyTypeError( + raise GuppyTypeInferenceError( f"Expected {kind} of type `{exp}`, got `{act}`. Couldn't infer an " f"instantiation for type variable `{act.quantified[i]}` (higher-" "rank polymorphic types are not supported)", @@ -387,6 +374,20 @@ def check_type_against( return subst, [] +def check_num_args(exp: int, act: int, node: AstNode) -> None: + """Checks that the correct number of arguments have been passed to a function.""" + if act < exp: + raise GuppyTypeError( + f"Not enough arguments passed (expected {exp}, got {act})", node + ) + if exp < act: + if isinstance(node, ast.Call): + raise GuppyTypeError("Unexpected argument", node.args[exp]) + raise GuppyTypeError( + f"Too many arguments passed (expected {exp}, got {act})", node + ) + + def type_check_args( args: list[ast.expr], func_ty: FunctionType, @@ -416,7 +417,7 @@ def type_check_args( # We also have to check that we found instantiations for all vars in the return type if not set.issubset(set(func_ty.returns.free_vars.keys()), subst.keys()): - raise GuppyTypeError( + raise GuppyTypeInferenceError( f"Cannot infer type variable in expression of type " f"`{func_ty.returns.substitute(subst)}`", node, @@ -470,24 +471,24 @@ def check_call( # x: int = foo(None) # ^^^^ Expected argument of type `int`, got `None` # - # The following error location would be more intuitive for users: + # But the following error location would be more intuitive for users: # # x: int = foo(None) - # ^^^^^^^^^ Expected expression of type `int`, got `None` + # ^^^^^^^^^ Expected expression of type `int`, got `None` # - # Therefore, we should only resort to using the type information for inference if - # the regular synthesis method doesn't succeed. + # In other words, if we can get away with synthesising the call without the extra + # information from the expected type, we should do that to improve the error. - # TODO: This approach can result in exponential runtime in the worst case. However - # the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common in - # practice + # TODO: The approach below can result in exponential runtime in the worst case. + # However the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common + # in practice. Can we do better than that? # First, try to synthesize res: Optional[tuple[GuppyType, Inst]] = None try: args, synth, inst = synthesize_call(func_ty, args, node, ctx) res = synth, inst - except GuppyTypeError: + except GuppyTypeInferenceError: pass if res is not None: synth, inst = res @@ -496,7 +497,8 @@ def check_call( raise GuppyTypeError(f"Expected {kind} of type `{ty}`, got `{synth}`", node) return args, subst, inst - # Only if synthesis fails, we try to infer more from the return type + # If synthesis fails, we try again, this time also using information from the + # expected return type unquantified, free_vars = func_ty.unquantified() subst = unify(ty, unquantified.returns, {}) if subst is None: @@ -510,7 +512,7 @@ def check_call( # Also make sure we found an instantiation for all free vars in the type we're # checking against if not set.issubset(set(ty.free_vars.keys()), subst.keys()): - raise GuppyTypeError( + raise GuppyTypeInferenceError( f"Expected expression of type `{ty}`, got " f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", node, diff --git a/guppy/error.py b/guppy/error.py index d06b9129..e8b56bde 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -73,6 +73,12 @@ class GuppyTypeError(GuppyError): pass +class GuppyTypeInferenceError(GuppyError): + """Special Guppy exception for type inference errors.""" + + pass + + class InternalGuppyError(Exception): """Exception for internal problems during compilation.""" From 4eba3dc9be088ad52aa3f52f949206857cb74732 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 27 Nov 2023 11:25:48 +0000 Subject: [PATCH 36/77] feat: Add builtins module (#60) --- guppy/custom.py | 5 + guppy/prelude/_internal.py | 249 ++++++++++++- guppy/prelude/builtins.py | 725 +++++++++++++++++++++++++++++++++++++ guppy/prelude/quantum.py | 35 +- 4 files changed, 1011 insertions(+), 3 deletions(-) diff --git a/guppy/custom.py b/guppy/custom.py index 6ed5ae0e..61800095 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -226,3 +226,8 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) return_ty = get_type(self.node) return [node.add_out_port(ty) for ty in type_to_row(return_ty)] + + +class NoopCompiler(CustomCallCompiler): + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + return args diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 66b40574..d3386459 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -1,13 +1,43 @@ -from typing import Literal +import ast +from typing import Optional, Literal from pydantic import BaseModel -from guppy.hugr import val +from guppy.ast_util import with_type, AstNode, with_loc, get_type +from guppy.checker.core import Context, CallableVariable +from guppy.checker.expr_checker import ExprSynthesizer, check_num_args +from guppy.custom import ( + CustomCallChecker, + DefaultCallChecker, + CustomFunction, + CustomCallCompiler, +) +from guppy.error import GuppyTypeError, GuppyError +from guppy.gtypes import GuppyType, FunctionType, BoolType +from guppy.hugr import ops, tys, val +from guppy.hugr.hugr import OutPortV +from guppy.nodes import GlobalCall INT_WIDTH = 6 # 2^6 = 64 bit +hugr_int_type = tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.BoundedNatArg(n=INT_WIDTH)], + bound=tys.TypeBound.Eq, +) + + +hugr_float_type = tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, +) + + class ConstIntS(BaseModel): """Hugr representation of signed integers in the arithmetic extension.""" @@ -36,3 +66,218 @@ def int_value(i: int) -> val.Value: def float_value(f: float) -> val.Value: """Returns the Hugr representation of a float value.""" return val.Prim(val=val.ExtensionVal(c=(ConstF64(value=f),))) + + +def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops.OpType: + """Utility method to create Hugr logic ops.""" + return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) + + +def int_op( + op_name: str, ext: str = "arithmetic.int", num_params: int = 1 +) -> ops.OpType: + """Utility method to create Hugr integer arithmetic ops.""" + return ops.CustomOp( + extension=ext, + op_name=op_name, + args=num_params * [tys.BoundedNatArg(n=INT_WIDTH)], + ) + + +def float_op(op_name: str, ext: str = "arithmetic.float") -> ops.OpType: + """Utility method to create Hugr integer arithmetic ops.""" + return ops.CustomOp(extension=ext, op_name=op_name, args=[]) + + +class CoercingChecker(DefaultCallChecker): + """Function call type checker that automatically coerces arguments to float.""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + from .builtins import Int + + for i in range(len(args)): + args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) + if isinstance(ty, self.ctx.globals.types["int"]): + call = with_loc( + self.node, GlobalCall(func=Int.__float__, args=[args[i]]) + ) + args[i] = with_type(self.ctx.globals.types["float"].build(), call) + return super().synthesize(args) + + +class ReversingChecker(CustomCallChecker): + """Call checker that reverses the arguments after checking.""" + + base_checker: CustomCallChecker + + def __init__(self, base_checker: Optional[CustomCallChecker] = None): + self.base_checker = base_checker or DefaultCallChecker() + + def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: + super()._setup(ctx, node, func) + self.base_checker._setup(ctx, node, func) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + expr = self.base_checker.check(args, ty) + if isinstance(expr, GlobalCall): + expr.args = list(reversed(args)) + return expr + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + expr, ty = self.base_checker.synthesize(args) + if isinstance(expr, GlobalCall): + expr.args = list(reversed(args)) + return expr, ty + + +class UnsupportedChecker(CustomCallChecker): + """Call checker for Python builtin functions that are not available in Guppy. + + Gives the uses a nicer error message when they try to use an unsupported feature. + """ + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + raise GuppyError( + f"Builtin method `{self.func.name}` is not supported by Guppy", self.node + ) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + raise GuppyError( + f"Builtin method `{self.func.name}` is not supported by Guppy", self.node + ) + + +class DunderChecker(CustomCallChecker): + """Call checker for builtin functions that call out to dunder instance methods""" + + dunder_name: str + num_args: int + + def __init__(self, dunder_name: str, num_args: int = 1): + assert num_args > 0 + self.dunder_name = dunder_name + self.num_args = num_args + + def _get_func( + self, args: list[ast.expr] + ) -> tuple[list[ast.expr], CallableVariable]: + check_num_args(self.num_args, len(args), self.node) + fst, *rest = args + fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) + func = self.ctx.globals.get_instance_func(ty, self.dunder_name) + if func is None: + raise GuppyTypeError( + f"Builtin function `{self.func.name}` is not defined for argument of " + f"type `{ty}`", + self.node.args[0] if isinstance(self.node, ast.Call) else self.node, + ) + return [fst, *rest], func + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + args, func = self._get_func(args) + return func.synthesize_call(args, self.node, self.ctx) + + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + args, func = self._get_func(args) + return func.check_call(args, ty, self.node, self.ctx) + + +class CallableChecker(CustomCallChecker): + """Call checker for the builtin `callable` function""" + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + check_num_args(1, len(args), self.node) + [arg] = args + arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) + is_callable = ( + isinstance(ty, FunctionType) + or self.ctx.globals.get_instance_func(ty, "__call__") is not None + ) + const = with_loc(self.node, ast.Constant(value=is_callable)) + return const, BoolType() + + +class IntTruedivCompiler(CustomCallCompiler): + """Compiler for the `int.__truediv__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Int, Float + + # Compile `truediv` using float arithmetic + [left, right] = args + [left] = Int.__float__.compile_call( + [left], self.dfg, self.graph, self.globals, self.node + ) + [right] = Int.__float__.compile_call( + [right], self.dfg, self.graph, self.globals, self.node + ) + return Float.__truediv__.compile_call( + [left, right], self.dfg, self.graph, self.globals, self.node + ) + + +class FloatBoolCompiler(CustomCallCompiler): + """Compiler for the `float.__bool__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: bool(x) = (x != 0.0) + zero_const = self.graph.add_constant( + float_value(0.0), get_type(self.node), self.dfg.node + ) + zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) + return Float.__ne__.compile_call( + [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node + ) + + +class FloatFloordivCompiler(CustomCallCompiler): + """Compiler for the `float.__floordiv__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: floordiv(x, y) = floor(truediv(x, y)) + [div] = Float.__truediv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [floor] = Float.__floor__.compile_call( + [div], self.dfg, self.graph, self.globals, self.node + ) + return [floor] + + +class FloatModCompiler(CustomCallCompiler): + """Compiler for the `float.__mod__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: mod(x, y) = x - (x // y) * y + [div] = Float.__floordiv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [mul] = Float.__mul__.compile_call( + [div, args[1]], self.dfg, self.graph, self.globals, self.node + ) + [sub] = Float.__sub__.compile_call( + [args[0], mul], self.dfg, self.graph, self.globals, self.node + ) + return [sub] + + +class FloatDivmodCompiler(CustomCallCompiler): + """Compiler for the `__divmod__` method.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + from .builtins import Float + + # We have: divmod(x, y) = (div(x, y), mod(x, y)) + [div] = Float.__truediv__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + [mod] = Float.__mod__.compile_call( + args, self.dfg, self.graph, self.globals, self.node + ) + return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index 98358bc9..25dfc741 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -1,6 +1,731 @@ """Guppy module for builtin types and operations.""" +# mypy: disable-error-code="empty-body, misc, override, no-untyped-def" + +from guppy.custom import NoopCompiler, DefaultCallChecker +from guppy.decorator import guppy +from guppy.gtypes import BoolType +from guppy.hugr import tys, ops from guppy.module import GuppyModule +from guppy.prelude._internal import ( + logic_op, + int_op, + hugr_int_type, + hugr_float_type, + float_op, + CoercingChecker, + ReversingChecker, + IntTruedivCompiler, + FloatBoolCompiler, + FloatDivmodCompiler, + FloatFloordivCompiler, + FloatModCompiler, + DunderChecker, + CallableChecker, + UnsupportedChecker, +) builtins = GuppyModule("builtins", import_builtins=False) + + +@guppy.extend_type(builtins, BoolType) +class Bool: + @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) + def __and__(self: bool, other: bool) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __bool__(self: bool) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ifrombool")) + def __int__(self: bool) -> int: + ... + + @guppy.hugr_op(builtins, logic_op("Or", [tys.BoundedNatArg(n=2)])) + def __or__(self: bool, other: bool) -> bool: + ... + + +@guppy.type(builtins, hugr_int_type, name="int") +class Int: + @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) + def __abs__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iadd")) + def __add__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iand")) + def __and__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("itobool")) + def __bool__(self: int) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __ceil__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2)) + def __divmod__(self: int, other: int) -> tuple[int, int]: + ... + + @guppy.hugr_op(builtins, int_op("ieq")) + def __eq__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("convert_s", "arithmetic.conversions")) + def __float__(self: int) -> float: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __floor__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2)) + def __floordiv__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ige_s")) + def __ge__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("igt_s")) + def __gt__(self: int, other: int) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __int__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("inot")) + def __invert__(self: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ile_s")) + def __le__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ishl", num_params=2)) # TODO: RHS is unsigned + def __lshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ilt_s")) + def __lt__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2)) + def __mod__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imul")) + def __mul__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ine")) + def __ne__(self: int, other: int) -> bool: + ... + + @guppy.hugr_op(builtins, int_op("ineg")) + def __neg__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ior")) + def __or__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __pos__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="ipow")) # TODO + def __pow__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker()) + def __radd__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("rand"), ReversingChecker()) + def __rand__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker()) + def __rdivmod__(self: int, other: int) -> tuple[int, int]: + ... + + @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker()) + def __rfloordiv__(self: int, other: int) -> int: + ... + + @guppy.hugr_op( + builtins, int_op("ishl", num_params=2), ReversingChecker() + ) # TODO: RHS is unsigned + def __rlshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker()) + def __rmod__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("imul"), ReversingChecker()) + def __rmul__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ior"), ReversingChecker()) + def __ror__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __round__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="ipow"), ReversingChecker()) # TODO + def __rpow__(self: int, other: int) -> int: + ... + + @guppy.hugr_op( + builtins, int_op("ishr", num_params=2), ReversingChecker() + ) # TODO: RHS is unsigned + def __rrshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ishr", num_params=2)) # TODO: RHS is unsigned + def __rshift__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("isub"), ReversingChecker()) + def __rsub__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker()) + def __rtruediv__(self: int, other: int) -> float: + ... + + @guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker()) + def __rxor__(self: int, other: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("isub")) + def __sub__(self: int, other: int) -> int: + ... + + @guppy.custom(builtins, IntTruedivCompiler()) + def __truediv__(self: int, other: int) -> float: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __trunc__(self: int) -> int: + ... + + @guppy.hugr_op(builtins, int_op("ixor")) + def __xor__(self: int, other: int) -> int: + ... + + +@guppy.type(builtins, hugr_float_type, name="float") +class Float: + @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) + def __abs__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fadd"), CoercingChecker()) + def __add__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatBoolCompiler(), CoercingChecker()) + def __bool__(self: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fceil"), CoercingChecker()) + def __ceil__(self: float) -> float: + ... + + @guppy.custom(builtins, FloatDivmodCompiler(), CoercingChecker()) + def __divmod__(self: float, other: float) -> tuple[float, float]: + ... + + @guppy.hugr_op(builtins, float_op("feq"), CoercingChecker()) + def __eq__(self: float, other: float) -> bool: + ... + + @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) + def __float__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("ffloor"), CoercingChecker()) + def __floor__(self: float) -> float: + ... + + @guppy.custom(builtins, FloatFloordivCompiler(), CoercingChecker()) + def __floordiv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fge"), CoercingChecker()) + def __ge__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fgt"), CoercingChecker()) + def __gt__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) + def __int__(self: float) -> int: + ... + + @guppy.hugr_op(builtins, float_op("fle"), CoercingChecker()) + def __le__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("flt"), CoercingChecker()) + def __lt__(self: float, other: float) -> bool: + ... + + @guppy.custom(builtins, FloatModCompiler(), CoercingChecker()) + def __mod__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fmul"), CoercingChecker()) + def __mul__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fne"), CoercingChecker()) + def __ne__(self: float, other: float) -> bool: + ... + + @guppy.hugr_op(builtins, float_op("fneg"), CoercingChecker()) + def __neg__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, NoopCompiler(), CoercingChecker()) + def __pos__(self: float) -> float: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="fpow")) # TODO + def __pow__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fadd"), ReversingChecker(CoercingChecker())) + def __radd__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatDivmodCompiler(), ReversingChecker(CoercingChecker())) + def __rdivmod__(self: float, other: float) -> tuple[float, float]: + ... + + @guppy.custom( + builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker()) + ) + def __rfloordiv__(self: float, other: float) -> float: + ... + + @guppy.custom(builtins, FloatModCompiler(), ReversingChecker(CoercingChecker())) + def __rmod__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fmul"), ReversingChecker(CoercingChecker())) + def __rmul__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="fround")) # TODO + def __round__(self: float) -> float: + ... + + @guppy.hugr_op( + builtins, ops.DummyOp(name="fpow"), ReversingChecker(DefaultCallChecker()) + ) # TODO + def __rpow__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fsub"), ReversingChecker(CoercingChecker())) + def __rsub__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fdiv"), ReversingChecker(CoercingChecker())) + def __rtruediv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fsub"), CoercingChecker()) + def __sub__(self: float, other: float) -> float: + ... + + @guppy.hugr_op(builtins, float_op("fdiv"), CoercingChecker()) + def __truediv__(self: float, other: float) -> float: + ... + + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) + def __trunc__(self: float) -> float: + ... + + +@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False) +def abs(x): + ... + + +@guppy.custom( + builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False +) +def _bool(x): + ... + + +@guppy.custom(builtins, checker=CallableChecker(), higher_order_value=False) +def callable(x): + ... + + +@guppy.custom( + builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False +) +def divmod(x, y): + ... + + +@guppy.custom( + builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False +) +def _float(x, y): + ... + + +@guppy.custom( + builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False +) +def _int(x): + ... + + +@guppy.custom( + builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False +) +def pow(x, y): + ... + + +@guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) +def round(x): + ... + + +# Python builtins that are not supported yet: + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def aiter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def all(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def anext(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def any(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bin(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def breakpoint(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bytearray(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def bytes(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def chr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def classmethod(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def compile(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def complex(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def delattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def dict(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def dir(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def enumerate(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def eval(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def exec(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def filter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def format(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def forozenset(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def getattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def globals(): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hasattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hash(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def help(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def hex(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def id(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def input(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def isinstance(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def issubclass(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def iter(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def len(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def list(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def locals(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def map(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def max(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def memoryview(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def min(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def next(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def object(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def oct(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def open(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def ord(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def print(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def property(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def range(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def repr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def reversed(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def set(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def setattr(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def slice(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def sorted(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def staticmethod(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def str(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def sum(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def super(x): + ... + + +@guppy.custom( + builtins, name="tuple", checker=UnsupportedChecker(), higher_order_value=False +) +def _tuple(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def type(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def vars(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def zip(x): + ... + + +@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) +def __import__(x): + ... diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index e1ef0f08..3b40af88 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -3,7 +3,7 @@ # mypy: disable-error-code=empty-body from guppy.decorator import guppy -from guppy.hugr import tys +from guppy.hugr import tys, ops from guppy.hugr.tys import TypeBound from guppy.module import GuppyModule @@ -11,6 +11,11 @@ quantum = GuppyModule("quantum") +def quantum_op(op_name: str) -> ops.OpType: + """Utility method to create Hugr quantum ops.""" + return ops.CustomOp(extension="quantum.tket2", op_name=op_name, args=[]) + + @guppy.type( quantum, tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), @@ -20,13 +25,41 @@ class Qubit: pass +@guppy.hugr_op(quantum, quantum_op("H")) def h(q: Qubit) -> Qubit: ... +@guppy.hugr_op(quantum, quantum_op("CX")) def cx(control: Qubit, target: Qubit) -> tuple[Qubit, Qubit]: ... +@guppy.hugr_op(quantum, quantum_op("RzF64")) +def rz(q: Qubit, angle: float) -> Qubit: + ... + + +@guppy.hugr_op(quantum, quantum_op("Measure")) def measure(q: Qubit) -> tuple[Qubit, bool]: ... + + +@guppy.hugr_op(quantum, quantum_op("T")) +def t(q: Qubit) -> Qubit: + ... + + +@guppy.hugr_op(quantum, quantum_op("Tdg")) +def tdg(q: Qubit) -> Qubit: + ... + + +@guppy.hugr_op(quantum, quantum_op("Z")) +def z(q: Qubit) -> Qubit: + ... + + +@guppy.hugr_op(quantum, quantum_op("X")) +def x(q: Qubit) -> Qubit: + ... From d6982d502fbc1c43842fd8109202c05b97a88d92 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 11:28:56 +0000 Subject: [PATCH 37/77] Fix mypy issues --- guppy/cfg/bb.py | 5 +++-- guppy/cfg/builder.py | 8 +++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index eb6e0ff9..2d414829 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -34,7 +34,7 @@ def update_used(self, node: ast.AST) -> None: self.used[name.id] = name -BBStatement = Union[ast.Assign, ast.AugAssign, ast.Expr, ast.Return, NestedFunctionDef] +BBStatement = Union[ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef] @dataclass(eq=False) # Disable equality to recover hash from `object` @@ -104,7 +104,8 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None: self.stats.assigned[name.id] = node def visit_AnnAssign(self, node: ast.AnnAssign) -> None: - self.stats.update_used(node.value) + if node.value: + self.stats.update_used(node.value) for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 71f1da1c..5607f3c2 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -3,7 +3,7 @@ from typing import Optional, Iterator, Union, NamedTuple from guppy.ast_util import set_location_from, AstVisitor -from guppy.cfg.bb import BB +from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError @@ -75,15 +75,13 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[B bb_opt = self.visit(node, bb_opt, jumps) return bb_opt - def _build_node_value( - self, node: Union[ast.Assign, ast.AugAssign, ast.Return, ast.Expr], bb: BB - ) -> BB: + def _build_node_value(self, node: BBStatement, bb: BB) -> BB: """Utility method for building a node containing a `value` expression. Builds the expression and mutates `node.value` to point to the built expression. Returns the BB in which the expression is available and adds the node to it. """ - if node.value is not None: + if not isinstance(node, NestedFunctionDef) and node.value is not None: node.value, bb = ExprBuilder.build(node.value, self.cfg, bb) bb.statements.append(node) return bb From 3e89ac8ca6aee407b5fef2a5aa7d0d756aac25ca Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 11:29:13 +0000 Subject: [PATCH 38/77] Run formatting --- guppy/cfg/bb.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index 2d414829..f5831dbf 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -34,7 +34,9 @@ def update_used(self, node: ast.AST) -> None: self.used[name.id] = name -BBStatement = Union[ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef] +BBStatement = Union[ + ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef +] @dataclass(eq=False) # Disable equality to recover hash from `object` From b9c81efa6ece27122006c590a70bf3cbc1398029 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 11:35:59 +0000 Subject: [PATCH 39/77] Add missing check method --- guppy/prelude/_internal.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index d3386459..d27fceb1 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -196,6 +196,14 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: const = with_loc(self.node, ast.Constant(value=is_callable)) return const, BoolType() + def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + args, _ = self.synthesize(args) + if not isinstance(ty, BoolType): + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `bool`", self.node + ) + return args + class IntTruedivCompiler(CustomCallCompiler): """Compiler for the `int.__truediv__` method.""" From 2aed4bfbde6ee0a0c6b55d0d7c1d6b3767176585 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 11:45:43 +0000 Subject: [PATCH 40/77] Add some tests --- tests/error/poly_errors/__init__.py | 0 tests/error/poly_errors/arg_mismatch1.err | 7 +++++++ tests/error/poly_errors/arg_mismatch1.py | 20 +++++++++++++++++++ tests/error/poly_errors/arg_mismatch2.err | 7 +++++++ tests/error/poly_errors/arg_mismatch2.py | 20 +++++++++++++++++++ tests/error/poly_errors/free_return_var.err | 7 +++++++ tests/error/poly_errors/free_return_var.py | 20 +++++++++++++++++++ .../poly_errors/inst_return_mismatch.err | 7 +++++++ .../error/poly_errors/inst_return_mismatch.py | 20 +++++++++++++++++++ .../inst_return_mismatch_nested.err | 7 +++++++ .../inst_return_mismatch_nested.py | 20 +++++++++++++++++++ tests/error/poly_errors/return_mismatch.err | 7 +++++++ tests/error/poly_errors/return_mismatch.py | 20 +++++++++++++++++++ 13 files changed, 162 insertions(+) create mode 100644 tests/error/poly_errors/__init__.py create mode 100644 tests/error/poly_errors/arg_mismatch1.err create mode 100644 tests/error/poly_errors/arg_mismatch1.py create mode 100644 tests/error/poly_errors/arg_mismatch2.err create mode 100644 tests/error/poly_errors/arg_mismatch2.py create mode 100644 tests/error/poly_errors/free_return_var.err create mode 100644 tests/error/poly_errors/free_return_var.py create mode 100644 tests/error/poly_errors/inst_return_mismatch.err create mode 100644 tests/error/poly_errors/inst_return_mismatch.py create mode 100644 tests/error/poly_errors/inst_return_mismatch_nested.err create mode 100644 tests/error/poly_errors/inst_return_mismatch_nested.py create mode 100644 tests/error/poly_errors/return_mismatch.err create mode 100644 tests/error/poly_errors/return_mismatch.py diff --git a/tests/error/poly_errors/__init__.py b/tests/error/poly_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/poly_errors/arg_mismatch1.err b/tests/error/poly_errors/arg_mismatch1.err new file mode 100644 index 00000000..211407eb --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: bool, y: tuple[bool]) -> None: +17: foo(x, y) + ^ +GuppyTypeError: Expected argument of type `bool`, got `(bool)` diff --git a/tests/error/poly_errors/arg_mismatch1.py b/tests/error/poly_errors/arg_mismatch1.py new file mode 100644 index 00000000..0e631f22 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch1.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T, y: T) -> None: + ... + + +@guppy(module) +def main(x: bool, y: tuple[bool]) -> None: + foo(x, y) + + +module.compile() diff --git a/tests/error/poly_errors/arg_mismatch2.err b/tests/error/poly_errors/arg_mismatch2.err new file mode 100644 index 00000000..13b54199 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main() -> None: +17: foo(False) + ^^^^^ +GuppyTypeError: Expected argument of type `(?T, ?T)`, got `bool` diff --git a/tests/error/poly_errors/arg_mismatch2.py b/tests/error/poly_errors/arg_mismatch2.py new file mode 100644 index 00000000..78f82b0c --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch2.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: tuple[T, T]) -> None: + ... + + +@guppy(module) +def main() -> None: + foo(False) + + +module.compile() diff --git a/tests/error/poly_errors/free_return_var.err b/tests/error/poly_errors/free_return_var.err new file mode 100644 index 00000000..bd5e522c --- /dev/null +++ b/tests/error/poly_errors/free_return_var.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main() -> None: +17: x = foo() + ^^^^^ +GuppyTypeInferenceError: Cannot infer type variable in expression of type `?T` diff --git a/tests/error/poly_errors/free_return_var.py b/tests/error/poly_errors/free_return_var.py new file mode 100644 index 00000000..9ac9299e --- /dev/null +++ b/tests/error/poly_errors/free_return_var.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo() -> T: + ... + + +@guppy(module) +def main() -> None: + x = foo() + + +module.compile() diff --git a/tests/error/poly_errors/inst_return_mismatch.err b/tests/error/poly_errors/inst_return_mismatch.err new file mode 100644 index 00000000..90683da3 --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: bool) -> None: +17: y: None = foo(x) + ^^^^^^ +GuppyTypeError: Expected expression of type `None`, got `bool` diff --git a/tests/error/poly_errors/inst_return_mismatch.py b/tests/error/poly_errors/inst_return_mismatch.py new file mode 100644 index 00000000..885de5fa --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T) -> T: + ... + + +@guppy(module) +def main(x: bool) -> None: + y: None = foo(x) + + +module.compile() diff --git a/tests/error/poly_errors/inst_return_mismatch_nested.err b/tests/error/poly_errors/inst_return_mismatch_nested.err new file mode 100644 index 00000000..1c14fb58 --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch_nested.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: bool) -> None: +17: y: None = foo(foo(foo(x))) + ^^^^^^^^^^^^^^^^ +GuppyTypeError: Expected expression of type `None`, got `bool` diff --git a/tests/error/poly_errors/inst_return_mismatch_nested.py b/tests/error/poly_errors/inst_return_mismatch_nested.py new file mode 100644 index 00000000..b308f31f --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch_nested.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T) -> T: + ... + + +@guppy(module) +def main(x: bool) -> None: + y: None = foo(foo(foo(x))) + + +module.compile() diff --git a/tests/error/poly_errors/return_mismatch.err b/tests/error/poly_errors/return_mismatch.err new file mode 100644 index 00000000..4d3029b2 --- /dev/null +++ b/tests/error/poly_errors/return_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main() -> None: +17: x: bool = foo() + ^^^^^ +GuppyTypeError: Expected expression of type `bool`, got `(?T, ?T)` diff --git a/tests/error/poly_errors/return_mismatch.py b/tests/error/poly_errors/return_mismatch.py new file mode 100644 index 00000000..7be60db6 --- /dev/null +++ b/tests/error/poly_errors/return_mismatch.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo() -> tuple[T, T]: + ... + + +@guppy(module) +def main() -> None: + x: bool = foo() + + +module.compile() From a39389ca535e0ce663d96123ef03021ff985e630 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 15:20:50 +0000 Subject: [PATCH 41/77] Fix add_input_with_ports --- guppy/hugr/hugr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 59a00e97..96a647f1 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -379,7 +379,7 @@ def add_input_with_ports( self, output_tys: Sequence[GuppyType], parent: Optional[Node] = None ) -> tuple[VNode, list[OutPortV]]: """Adds an `Input` node to the graph.""" - node = self.add_input(list(output_tys), parent) + node = self.add_input(None, parent) ports = [node.add_out_port(ty) for ty in output_tys] return node, ports From 8aea545c3f8ec813d9a2531bd80fe574b3369839 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 15:21:56 +0000 Subject: [PATCH 42/77] Remove unused import --- guppy/cfg/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 5607f3c2..e937e144 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,6 +1,6 @@ import ast import itertools -from typing import Optional, Iterator, Union, NamedTuple +from typing import Optional, Iterator, NamedTuple from guppy.ast_util import set_location_from, AstVisitor from guppy.cfg.bb import BB, BBStatement From b2ad47843d62f2ca4427da5baa8bd7a2077ba51a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 15:23:55 +0000 Subject: [PATCH 43/77] Remove unused import --- guppy/custom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guppy/custom.py b/guppy/custom.py index 3263a1af..a80fd134 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -11,7 +11,6 @@ GuppyError, InternalGuppyError, UnknownFunctionType, - GuppyTypeError, ) from guppy.gtypes import GuppyType, FunctionType, type_to_row from guppy.hugr import ops From c3b1228c0c9b895bd156231ac62597804f5ae680 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 15:38:37 +0000 Subject: [PATCH 44/77] Fix small issues --- guppy/gtypes.py | 8 +------- tests/error/misc_errors/return_not_annotated.err | 2 +- tests/error/misc_errors/return_not_annotated_none1.err | 2 +- tests/error/misc_errors/return_not_annotated_none2.err | 2 +- tests/hugr/test_dummy_nodes.py | 2 +- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/guppy/gtypes.py b/guppy/gtypes.py index d306c514..02f8d718 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -253,9 +253,6 @@ def linear(self) -> bool: def type_args(self) -> Iterator[GuppyType]: return iter(self.element_types) - def substitute(self, s: Subst) -> GuppyType: - return TupleType([ty.substitute(s) for ty in self.element_types]) - def to_hugr(self) -> tys.SimpleType: ts = [t.to_hugr() for t in self.element_types] return tys.Tuple(inner=ts) @@ -289,9 +286,6 @@ def linear(self) -> bool: def type_args(self) -> Iterator[GuppyType]: return iter(self.element_types) - def substitute(self, s: Subst) -> GuppyType: - return TupleType([ty.substitute(s) for ty in self.element_types]) - def to_hugr(self) -> tys.SimpleType: if all( isinstance(e, TupleType) and len(e.element_types) == 0 @@ -301,7 +295,7 @@ def to_hugr(self) -> tys.SimpleType: return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or TupleType( + return transformer.transform(self) or SumType( [ty.transform(transformer) for ty in self.element_types] ) diff --git a/tests/error/misc_errors/return_not_annotated.err b/tests/error/misc_errors/return_not_annotated.err index 0c2eb656..879de215 100644 --- a/tests/error/misc_errors/return_not_annotated.err +++ b/tests/error/misc_errors/return_not_annotated.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(x: bool): - ^^^^^^^^^^^^ + ^^^^^^^^^^^^^^^^^^ GuppyError: Return type must be annotated diff --git a/tests/error/misc_errors/return_not_annotated_none1.err b/tests/error/misc_errors/return_not_annotated_none1.err index 58d79a43..82f22409 100644 --- a/tests/error/misc_errors/return_not_annotated_none1.err +++ b/tests/error/misc_errors/return_not_annotated_none1.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(): - ^^^^^^^^^^ + ^^^^^^^^^^^ GuppyError: Return type must be annotated. Try adding a `-> None` annotation. diff --git a/tests/error/misc_errors/return_not_annotated_none2.err b/tests/error/misc_errors/return_not_annotated_none2.err index cebd15a9..82f22409 100644 --- a/tests/error/misc_errors/return_not_annotated_none2.err +++ b/tests/error/misc_errors/return_not_annotated_none2.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(): - ^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^ GuppyError: Return type must be annotated. Try adding a `-> None` annotation. diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index e1efaab9..97c472b5 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -5,7 +5,7 @@ def test_single_dummy(): g = Hugr() - defn = g.add_def(FunctionType([BoolType()], [BoolType()]), g.root, "test") + defn = g.add_def(FunctionType([BoolType()], BoolType()), g.root, "test") dfg = g.add_dfg(defn) inp = g.add_input([BoolType()], dfg).out_port(0) dummy = g.add_node( From a83fff5d996b7a76d784d83fa107cf1680eeb694 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 16:27:46 +0000 Subject: [PATCH 45/77] Make everything work with Hugr --- guppy/checker/expr_checker.py | 29 ++++++++++++--------- guppy/compiler/core.py | 3 ++- guppy/compiler/expr_compiler.py | 4 ++- guppy/compiler/func_compiler.py | 12 +++++++-- guppy/custom.py | 22 +++++++++++++--- guppy/declared.py | 14 +++++++--- guppy/gtypes.py | 15 +++++++---- guppy/hugr/hugr.py | 16 ++++++++---- guppy/hugr/ops.py | 18 +++++++++++++ guppy/hugr/tys.py | 11 +++++--- guppy/prelude/_internal.py | 46 +++++++++++++++++++-------------- 11 files changed, 134 insertions(+), 56 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 9dc52d86..4c80d5df 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -3,8 +3,12 @@ from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt from guppy.checker.core import Context, CallableVariable, Globals -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError, \ - GuppyTypeInferenceError +from guppy.error import ( + GuppyError, + GuppyTypeError, + InternalGuppyError, + GuppyTypeInferenceError, +) from guppy.gtypes import ( GuppyType, TupleType, @@ -129,11 +133,7 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: # Otherwise, it must be a function as a higher-order value if isinstance(func_ty, FunctionType): args, return_ty, inst = check_call(func_ty, node.args, ty, node, self.ctx) - # Maybe we have to add a TypeApply node - if len(inst) > 0: - func = TypeApply(value=with_type(ty, node.func), tys=inst) - func = with_type(func_ty.instantiate(inst), func) - node.func = with_loc(node.func, func) + node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(func_ty, "__call__"): return f.check_call(node.args, ty, node, self.ctx) @@ -296,11 +296,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: # Otherwise, it must be a function as a higher-order value if isinstance(ty, FunctionType): args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) - # Maybe we have to add a TypeApply node - if len(inst) > 0: - func = TypeApply(value=with_type(ty, node.func), tys=inst) - func = with_type(ty.instantiate(inst), func) - node.func = with_loc(node.func, func) + node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) @@ -525,6 +521,15 @@ def check_call( return args, subst, inst +def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: + """Instantiates quantified type arguments in a function.""" + assert len(ty.quantified) == len(inst) + if len(inst) > 0: + node = with_loc(node, TypeApply(value=with_type(ty, node), tys=inst)) + return with_type(ty.instantiate(inst), node) + return node + + def to_bool( node: ast.expr, node_ty: GuppyType, ctx: Context ) -> tuple[ast.expr, GuppyType]: diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 2b6a4fb2..2d4c9cdd 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstNode from guppy.checker.core import Variable, CallableVariable -from guppy.gtypes import FunctionType +from guppy.gtypes import FunctionType, GuppyType, Inst from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr @@ -47,6 +47,7 @@ class CompiledFunction(CompiledVariable, CallableVariable): def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 3dac0690..5079b1dd 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -71,7 +71,9 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: assert isinstance(func, CompiledFunction) args = [self.visit(arg) for arg in node.args] - rets = func.compile_call(args, self.dfg, self.graph, self.globals, node) + rets = func.compile_call( + args, list(node.type_args), self.dfg, self.graph, self.globals, node + ) return self._pack_returns(rets) def visit_Call(self, node: ast.Call) -> OutPortV: diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index 4631f0db..b8d355cb 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -9,7 +9,7 @@ DFContainer, PortVariable, ) -from guppy.gtypes import type_to_row, FunctionType +from guppy.gtypes import type_to_row, FunctionType, Inst from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef @@ -26,12 +26,20 @@ def load( def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - call = graph.add_call(self.node.out_port(0), args, dfg.node) + # TODO: Hugr should probably allow us to pass type args to `Call`, so we can + # avoid loading the function to manually add a `TypeApply` + if type_args: + func = graph.add_load_constant(self.node.out_port(0), dfg.node) + func = graph.add_type_apply(func.out_port(0), type_args, dfg.node) + call = graph.add_indirect_call(func.out_port(0), args, dfg.node) + else: + call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/custom.py b/guppy/custom.py index af1dd866..d982f4b0 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -12,7 +12,7 @@ InternalGuppyError, UnknownFunctionType, ) -from guppy.gtypes import GuppyType, FunctionType, Subst, type_to_row +from guppy.gtypes import GuppyType, FunctionType, Subst, type_to_row, Inst from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode from guppy.nodes import GlobalCall @@ -100,12 +100,13 @@ def synthesize_call( def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - self.call_compiler._setup(dfg, graph, globals, node) + self.call_compiler._setup(type_args, dfg, graph, globals, node) return self.call_compiler.compile(args) def load( @@ -123,6 +124,12 @@ def load( node, ) + if self._ty.quantified: + raise InternalGuppyError( + "Can't yet generate higher-order versions of custom functions. This " + "requires generic function *definitions*" + ) + # Find the module node by walking up the hierarchy module: Node = dfg.node while not isinstance(module.op, ops.Module): @@ -139,7 +146,7 @@ def load( def_node = graph.add_def(self.ty, module, self.name) _, inp_ports = graph.add_input_with_ports(list(self.ty.args), def_node) returns = self.compile_call( - inp_ports, DFContainer(def_node, {}), graph, globals, node + inp_ports, [], DFContainer(def_node, {}), graph, globals, node ) graph.add_output(returns, parent=def_node) self._defined[module] = def_node @@ -180,14 +187,21 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: class CustomCallCompiler(ABC): """Protocol for custom function call compilers.""" + type_args: Inst dfg: DFContainer graph: Hugr globals: CompiledGlobals node: AstNode def _setup( - self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + self, + type_args: Inst, + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, ) -> None: + self.type_args = type_args self.dfg = dfg self.graph = graph self.globals = globals diff --git a/guppy/declared.py b/guppy/declared.py index 29bda97e..72002744 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -4,11 +4,11 @@ from guppy.ast_util import AstNode, has_empty_body from guppy.checker.core import Globals, Context -from guppy.checker.expr_checker import check_call, synthesize_call +from guppy.checker.expr_checker import check_call, synthesize_call, instantiate_poly from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError -from guppy.gtypes import type_to_row, GuppyType, Subst +from guppy.gtypes import type_to_row, GuppyType, Subst, Inst from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV from guppy.nodes import GlobalCall @@ -56,11 +56,19 @@ def load( def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: assert self.node is not None - call = graph.add_call(self.node.out_port(0), args, dfg.node) + # TODO: Hugr should probably allow us to pass type args to `Call`, so we can + # avoid loading the function to manually add a `TypeApply` + if type_args: + func = graph.add_load_constant(self.node.out_port(0), dfg.node) + func = graph.add_type_apply(func.out_port(0), type_args, dfg.node) + call = graph.add_indirect_call(func.out_port(0), args, dfg.node) + else: + call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 02f8d718..7612c3ee 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -63,7 +63,7 @@ def __post_init__(self) -> None: if isinstance(self, FreeTypeVar): vs = {self.id: self} else: - vs: dict[TypeVarId, FreeTypeVar] = {} + vs = {} for arg in self.type_args: vs |= arg.free_vars object.__setattr__(self, "_free_vars", vs) @@ -124,8 +124,7 @@ def __str__(self) -> str: return self.display_name def to_hugr(self) -> tys.SimpleType: - # TODO - return NoneType().to_hugr() + return tys.Variable(i=self.idx, b=tys.TypeBound.from_linear(self.linear)) @dataclass(frozen=True) @@ -195,11 +194,17 @@ def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: def type_args(self) -> Iterator[GuppyType]: return itertools.chain(iter(self.args), iter((self.returns,))) - def to_hugr(self) -> tys.SimpleType: + def to_hugr(self) -> tys.PolyFuncType: ins = [t.to_hugr() for t in self.args] outs = [t.to_hugr() for t in type_to_row(self.returns)] func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) - return tys.PolyFuncType(params=[], body=func_ty) + return tys.PolyFuncType( + params=[ + tys.TypeParam(b=tys.TypeBound.from_linear(v.linear)) + for v in self.quantified + ], + body=func_ty, + ) def transform(self, transformer: "TypeTransformer") -> GuppyType: return transformer.transform(self) or FunctionType( diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index eed9b84e..78ad0d37 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -17,7 +17,7 @@ row_to_type, Inst, ) -from guppy.hugr import val +from guppy.hugr import val, tys NodeIdx = int PortOffset = int @@ -533,15 +533,21 @@ def add_partial( ) def add_type_apply( - self, func_port: OutPortV, tys: Inst, parent: Optional[Node] = None + self, func_port: OutPortV, args: Inst, parent: Optional[Node] = None ) -> VNode: """Adds a `TypeApply` node to the graph.""" assert isinstance(func_port.ty, FunctionType) - assert len(func_port.ty.quantified) == len(tys) + assert len(func_port.ty.quantified) == len(args) + result_ty = func_port.ty.instantiate(args) + ta = ops.TypeApplication( + input=func_port.ty.to_hugr(), + args=[tys.TypeArg(ty=ty.to_hugr()) for ty in args], + output=result_ty.to_hugr(), + ) return self.add_node( - ops.DummyOp(name="TypeApply"), + ops.TypeApply(ta=ta), inputs=[func_port], - output_types=[func_port.ty.instantiate(tys)], + output_types=[result_ty], parent=parent, ) diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index 81a3fb7c..c270189d 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -2,6 +2,7 @@ import sys from abc import ABC from typing import Annotated, Literal, Union, Optional, Any + from pydantic import Field, BaseModel from .tys import ( @@ -499,6 +500,22 @@ class Tag(LeafOp): variants: TypeRow # The variants of the sum type. +class TypeApply(LeafOp): + """Fixes some TypeParams of a polymorphic type by providing TypeArgs""" + + lop: Literal["TypeApply"] = "TypeApply" + ta: "TypeApplication" + + +class TypeApplication(BaseModel): + """Records details of an application of a PolyFuncType to some TypeArgs and the + result (a less-, but still potentially-, polymorphic type).""" + + input: PolyFuncType + args: list[tys.TypeArg] + output: PolyFuncType + + LeafOpUnion = Annotated[ Union[ CustomOp, @@ -521,6 +538,7 @@ class Tag(LeafOp): UnpackTuple, MakeNewType, Tag, + TypeApply, ], Field(discriminator="lop"), ] diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index ebc032d6..c99fe615 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -144,10 +144,11 @@ class GeneralSum(Sum): class Variable(BaseModel): - """A type variable identified by a name.""" + """A type variable identified by a de Bruijn index.""" - t: Literal["Var"] = "Var" - name: str + t: Literal["V"] = "V" + i: int + b: "TypeBound" class Int(BaseModel): @@ -208,6 +209,10 @@ class TypeBound(Enum): Copyable = "C" Any = "A" + @staticmethod + def from_linear(linear: bool) -> "TypeBound": + return TypeBound.Any if linear else TypeBound.Eq + class Opaque(BaseModel): """An opaque operation that can be downcasted by the extensions that define it.""" diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index d8a64c32..ad6e4048 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -13,7 +13,7 @@ CustomCallCompiler, ) from guppy.error import GuppyTypeError, GuppyError -from guppy.gtypes import GuppyType, FunctionType, BoolType +from guppy.gtypes import GuppyType, FunctionType, BoolType, Subst, unify from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV from guppy.nodes import GlobalCall @@ -117,11 +117,11 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: super()._setup(ctx, node, func) self.base_checker._setup(ctx, node, func) - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - expr = self.base_checker.check(args, ty) + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + expr, subst = self.base_checker.check(args, ty) if isinstance(expr, GlobalCall): expr.args = list(reversed(args)) - return expr + return expr, subst def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: expr, ty = self.base_checker.synthesize(args) @@ -141,7 +141,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: f"Builtin method `{self.func.name}` is not supported by Guppy", self.node ) - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: raise GuppyError( f"Builtin method `{self.func.name}` is not supported by Guppy", self.node ) @@ -177,7 +177,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: args, func = self._get_func(args) return func.synthesize_call(args, self.node, self.ctx) - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: args, func = self._get_func(args) return func.check_call(args, ty, self.node, self.ctx) @@ -196,13 +196,14 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: const = with_loc(self.node, ast.Constant(value=is_callable)) return const, BoolType() - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: args, _ = self.synthesize(args) - if not isinstance(ty, BoolType): + subst = unify(ty, BoolType(), {}) + if subst is None: raise GuppyTypeError( f"Expected expression of type `{ty}`, got `bool`", self.node ) - return args + return args, subst class IntTruedivCompiler(CustomCallCompiler): @@ -214,13 +215,13 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # Compile `truediv` using float arithmetic [left, right] = args [left] = Int.__float__.compile_call( - [left], self.dfg, self.graph, self.globals, self.node + [left], [], self.dfg, self.graph, self.globals, self.node ) [right] = Int.__float__.compile_call( - [right], self.dfg, self.graph, self.globals, self.node + [right], [], self.dfg, self.graph, self.globals, self.node ) return Float.__truediv__.compile_call( - [left, right], self.dfg, self.graph, self.globals, self.node + [left, right], [], self.dfg, self.graph, self.globals, self.node ) @@ -236,7 +237,12 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: ) zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) return Float.__ne__.compile_call( - [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node + [args[0], zero.out_port(0)], + [], + self.dfg, + self.graph, + self.globals, + self.node, ) @@ -248,10 +254,10 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # We have: floordiv(x, y) = floor(truediv(x, y)) [div] = Float.__truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) [floor] = Float.__floor__.compile_call( - [div], self.dfg, self.graph, self.globals, self.node + [div], [], self.dfg, self.graph, self.globals, self.node ) return [floor] @@ -264,13 +270,13 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # We have: mod(x, y) = x - (x // y) * y [div] = Float.__floordiv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) [mul] = Float.__mul__.compile_call( - [div, args[1]], self.dfg, self.graph, self.globals, self.node + [div, args[1]], [], self.dfg, self.graph, self.globals, self.node ) [sub] = Float.__sub__.compile_call( - [args[0], mul], self.dfg, self.graph, self.globals, self.node + [args[0], mul], [], self.dfg, self.graph, self.globals, self.node ) return [sub] @@ -283,9 +289,9 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # We have: divmod(x, y) = (div(x, y), mod(x, y)) [div] = Float.__truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) [mod] = Float.__mod__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] From 559469b4b5a6c4f71431b92821b4fa953dea6684 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 16:38:54 +0000 Subject: [PATCH 46/77] Add type args to CoercingChecker --- guppy/prelude/_internal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index ad6e4048..2f09785d 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -99,7 +99,8 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) if isinstance(ty, self.ctx.globals.types["int"]): call = with_loc( - self.node, GlobalCall(func=Int.__float__, args=[args[i]]) + self.node, + GlobalCall(func=Int.__float__, args=[args[i]], type_args=[]), ) args[i] = with_type(self.ctx.globals.types["float"].build(), call) return super().synthesize(args) From 8da2b4af4ef8685cea3833975d9befe2db7104c8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 27 Nov 2023 16:51:10 +0000 Subject: [PATCH 47/77] Make non-linear args copyable --- guppy/hugr/tys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index c99fe615..5595782d 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -211,7 +211,7 @@ class TypeBound(Enum): @staticmethod def from_linear(linear: bool) -> "TypeBound": - return TypeBound.Any if linear else TypeBound.Eq + return TypeBound.Any if linear else TypeBound.Copyable class Opaque(BaseModel): From 9a0647fd8c0253d727df988d90839d6f7f05eb7f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 12:01:27 +0000 Subject: [PATCH 48/77] Protect instantiated tuples from being turned into rows --- guppy/compiler/expr_compiler.py | 31 ++++++++++++++++++++++++++++--- guppy/gtypes.py | 26 +++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 5079b1dd..e43a4c58 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -3,8 +3,9 @@ from guppy.ast_util import AstVisitor, get_type from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction -from guppy.error import InternalGuppyError -from guppy.gtypes import FunctionType, type_to_row, BoolType +from guppy.error import InternalGuppyError, GuppyError +from guppy.gtypes import FunctionType, type_to_row, BoolType, BoundTypeVar, TupleType, \ + Inst, NoneType from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall, TypeApply @@ -81,7 +82,23 @@ def visit_Call(self, node: ast.Call) -> OutPortV: def visit_TypeApply(self, node: TypeApply) -> OutPortV: func = self.visit(node.value) - return self.graph.add_type_apply(func, node.tys, self.dfg.node).out_port(0) + assert isinstance(func.ty, FunctionType) + ta = self.graph.add_type_apply(func, node.tys, self.dfg.node).out_port(0) + + # We have to be very careful here: If we instantiate `foo: forall T. T -> T` + # with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`. + # Normally, this would be represented in Hugr as a function with two output + # ports types A and B. However, when TypeApplying `foo`, we actually get a + # function with a single output port typed `tuple[A, B]`. + # TODO: We would need to do manual monomorphisation in that case to obtain a + # function that returns two ports as expected + if instantiation_needs_unpacking(func.ty, node.tys): + raise GuppyError( + "Generic function instantiations returning rows are not supported yet", + node + ) + + return ta def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: # The only case that is not desugared by the type checker is the `not` operation @@ -106,6 +123,14 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]: return expr.elts if isinstance(expr, ast.Tuple) else [expr] +def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: + """Checks if instantiating a polymorphic makes it return a row.""" + if isinstance(func_ty.returns, BoundTypeVar): + return_ty = inst[func_ty.returns.idx] + return isinstance(return_ty, TupleType) or isinstance(return_ty, NoneType) + return False + + def python_value_to_hugr(v: Any) -> Optional[val.Value]: """Turns a Python value into a Hugr value. diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 7612c3ee..87a6e17b 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -216,7 +216,17 @@ def transform(self, transformer: "TypeTransformer") -> GuppyType: def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": """Instantiates quantified type variables.""" assert len(tys) == len(self.quantified) - inst = Instantiator(tys) + + # Set the `preserve` flag for instantiated tuples and None + preserved_tys: list[GuppyType] = [] + for ty in tys: + if isinstance(ty, TupleType): + ty = TupleType(ty.element_types, preserve=True) + elif isinstance(ty, NoneType): + ty = NoneType(preserve=True) + preserved_tys.append(ty) + + inst = Instantiator(preserved_tys) return FunctionType( [ty.transform(inst) for ty in self.args], self.returns.transform(inst), @@ -236,6 +246,11 @@ def unquantified(self) -> tuple["FunctionType", Sequence[FreeTypeVar]]: class TupleType(GuppyType): element_types: Sequence[GuppyType] + # Flag to avoid turning the tuple into row when calling `type_to_row()`. This is + # used to make sure that type vars instantiated to tuples are not broken up into + # rows when generating a Hugr + preserve: bool = field(default=False, compare=False) + name: ClassVar[Literal["tuple"]] = "tuple" @staticmethod @@ -310,6 +325,11 @@ class NoneType(GuppyType): name: ClassVar[Literal["None"]] = "None" linear: bool = False + # Flag to avoid turning the type into a row when calling `type_to_row()`. This is + # used to make sure that type vars instantiated to Nones are not broken up into + # empty rows when generating a Hugr + preserve: bool = field(default=False, compare=False) + @staticmethod def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: if len(args) > 0: @@ -536,8 +556,8 @@ def row_to_type(row: Sequence[GuppyType]) -> GuppyType: def type_to_row(ty: GuppyType) -> Sequence[GuppyType]: """Turns a type into a row of types by unpacking top-level tuples.""" - if isinstance(ty, NoneType): + if isinstance(ty, NoneType) and not ty.preserve: return [] - if isinstance(ty, TupleType): + if isinstance(ty, TupleType) and not ty.preserve: return ty.element_types return [ty] From 05fe2e5a95a86fdc8a97a731d2cee029dc308e5b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 12:16:31 +0000 Subject: [PATCH 49/77] Add more tests --- tests/error/poly_errors/pass_poly_free.err | 7 + tests/error/poly_errors/pass_poly_free.py | 27 +++ tests/error/poly_errors/right_to_left.err | 7 + tests/error/poly_errors/right_to_left.py | 25 +++ tests/error/test_poly_errors.py | 15 ++ tests/integration/test_poly.py | 210 +++++++++++++++++++++ 6 files changed, 291 insertions(+) create mode 100644 tests/error/poly_errors/pass_poly_free.err create mode 100644 tests/error/poly_errors/pass_poly_free.py create mode 100644 tests/error/poly_errors/right_to_left.err create mode 100644 tests/error/poly_errors/right_to_left.py create mode 100644 tests/error/test_poly_errors.py create mode 100644 tests/integration/test_poly.py diff --git a/tests/error/poly_errors/pass_poly_free.err b/tests/error/poly_errors/pass_poly_free.err new file mode 100644 index 00000000..5154e2df --- /dev/null +++ b/tests/error/poly_errors/pass_poly_free.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:24 + +22: @guppy(module) +23: def main() -> None: +24: foo(bar) + ^^^ +GuppyTypeInferenceError: Expected argument of type `?T -> ?T`, got `forall T. T -> T`. Couldn't infer an instantiation for type variable `T` (higher-rank polymorphic types are not supported) diff --git a/tests/error/poly_errors/pass_poly_free.py b/tests/error/poly_errors/pass_poly_free.py new file mode 100644 index 00000000..d88a6b83 --- /dev/null +++ b/tests/error/poly_errors/pass_poly_free.py @@ -0,0 +1,27 @@ +from typing import Callable + +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(f: Callable[[T], T]) -> None: + ... + + +@guppy.declare(module) +def bar(x: T) -> T: + ... + + +@guppy(module) +def main() -> None: + foo(bar) + + +module.compile() diff --git a/tests/error/poly_errors/right_to_left.err b/tests/error/poly_errors/right_to_left.err new file mode 100644 index 00000000..b488be7d --- /dev/null +++ b/tests/error/poly_errors/right_to_left.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:22 + +20: @guppy(module) +21: def main() -> None: +22: bar(foo(), 42) + ^^^^^ +GuppyTypeInferenceError: Cannot infer type variable in expression of type `?T` diff --git a/tests/error/poly_errors/right_to_left.py b/tests/error/poly_errors/right_to_left.py new file mode 100644 index 00000000..1fec068e --- /dev/null +++ b/tests/error/poly_errors/right_to_left.py @@ -0,0 +1,25 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo() -> T: + ... + + +@guppy.declare(module) +def bar(x: T, y: T) -> None: + ... + + +@guppy(module) +def main() -> None: + bar(foo(), 42) + + +module.compile() diff --git a/tests/error/test_poly_errors.py b/tests/error/test_poly_errors.py new file mode 100644 index 00000000..a140769f --- /dev/null +++ b/tests/error/test_poly_errors.py @@ -0,0 +1,15 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "poly_errors" +files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_type_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py new file mode 100644 index 00000000..29728279 --- /dev/null +++ b/tests/integration/test_poly.py @@ -0,0 +1,210 @@ +from typing import Callable + +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +def test_id(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> int: + return foo(x) + + validate(module.compile()) + + +def test_id_nested(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> int: + return foo(foo(foo(x))) + + validate(module.compile()) + + +def test_use_twice(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int, y: bool) -> None: + foo(x) + foo(y) + + validate(module.compile()) + + +def test_define_twice(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy.declare(module) + def bar(x: T) -> T: # Reuse same type var! + ... + + @guppy(module) + def main(x: bool, y: float) -> None: + foo(x) + foo(y) + + validate(module.compile()) + + +def test_return_tuple_implicit(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> tuple[int, int]: + return foo((x, 0)) + + validate(module.compile()) + + +def test_same_args(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T, y: T) -> None: + ... + + @guppy(module) + def main(x: int) -> None: + foo(x, 42) + + validate(module.compile()) + + +def test_different_args(validate): + module = GuppyModule("test") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: S, y: T, z: tuple[S, T]) -> T: + ... + + @guppy(module) + def main(x: int, y: float) -> float: + return foo(x, y, (x, y)) + foo(y, 42.0, (0.0, y)) + + validate(module.compile()) + + +def test_infer_basic(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy(module) + def main() -> None: + x: int = foo() + + validate(module.compile()) + + +def test_infer_nested(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy.declare(module) + def bar(x: T) -> T: + ... + + @guppy(module) + def main() -> None: + x: int = bar(foo()) + + validate(module.compile()) + + +def test_infer_left_to_right(validate): + module = GuppyModule("test") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy.declare(module) + def bar(x: T, y: T, z: S, a: tuple[T, S]) -> None: + ... + + @guppy(module) + def main() -> None: + bar(42, foo(), False, foo()) + + validate(module.compile()) + + +def test_pass_poly_basic(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(f: Callable[[T], T]) -> None: + ... + + @guppy.declare(module) + def bar(x: int) -> int: + ... + + @guppy(module) + def main() -> None: + foo(bar) + + validate(module.compile()) + + +def test_pass_poly_cross(validate): + module = GuppyModule("test") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(f: Callable[[S], int]) -> None: + ... + + @guppy.declare(module) + def bar(x: bool) -> T: + ... + + @guppy(module) + def main() -> None: + foo(bar) + + validate(module.compile()) + From 9bbfc865201670ae86e9e91bea391c870949390e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 12:17:00 +0000 Subject: [PATCH 50/77] Run formatting --- guppy/compiler/core.py | 2 +- guppy/compiler/expr_compiler.py | 13 ++++++++++--- guppy/declared.py | 2 +- guppy/gtypes.py | 3 +-- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 2d4c9cdd..9ddc4280 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstNode from guppy.checker.core import Variable, CallableVariable -from guppy.gtypes import FunctionType, GuppyType, Inst +from guppy.gtypes import FunctionType, Inst from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index e43a4c58..bf145466 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -4,8 +4,15 @@ from guppy.ast_util import AstVisitor, get_type from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction from guppy.error import InternalGuppyError, GuppyError -from guppy.gtypes import FunctionType, type_to_row, BoolType, BoundTypeVar, TupleType, \ - Inst, NoneType +from guppy.gtypes import ( + FunctionType, + type_to_row, + BoolType, + BoundTypeVar, + TupleType, + Inst, + NoneType, +) from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall, TypeApply @@ -95,7 +102,7 @@ def visit_TypeApply(self, node: TypeApply) -> OutPortV: if instantiation_needs_unpacking(func.ty, node.tys): raise GuppyError( "Generic function instantiations returning rows are not supported yet", - node + node, ) return ta diff --git a/guppy/declared.py b/guppy/declared.py index 72002744..3dbea36b 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstNode, has_empty_body from guppy.checker.core import Globals, Context -from guppy.checker.expr_checker import check_call, synthesize_call, instantiate_poly +from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 87a6e17b..fc467967 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -1,5 +1,4 @@ import ast -import functools import itertools from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -14,7 +13,7 @@ ) import guppy.hugr.tys as tys -from guppy.ast_util import AstNode, set_location_from +from guppy.ast_util import AstNode if TYPE_CHECKING: from guppy.checker.core import Globals From b1bdbe6f72a71f6cce12adb98ac095451aa126a7 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 12:32:24 +0000 Subject: [PATCH 51/77] Store free vars in set instead of dict --- guppy/checker/expr_checker.py | 22 ++++++------- guppy/gtypes.py | 60 ++++++++++++++++------------------- 2 files changed, 38 insertions(+), 44 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 4c80d5df..efc6cc83 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -93,7 +93,7 @@ def check( # When checking against a variable, we have to synthesize if isinstance(ty, FreeTypeVar): expr, syn_ty = self._synthesize(expr, allow_free_vars=False) - return with_type(syn_ty, expr), {ty.id: syn_ty} + return with_type(syn_ty, expr), {ty: syn_ty} # Otherwise, invoke the visitor old_kind = self._kind @@ -344,21 +344,21 @@ def check_type_against( raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) # Check that we have found a valid instantiation for all quantified vars for i, v in enumerate(free_vars): - if v.id not in subst: + if v not in subst: raise GuppyTypeInferenceError( f"Expected {kind} of type `{exp}`, got `{act}`. Couldn't infer an " f"instantiation for type variable `{act.quantified[i]}` (higher-" "rank polymorphic types are not supported)", node, ) - if subst[v.id].free_vars: + if subst[v].free_vars: raise GuppyTypeError( f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " - f"type variable `{act.quantified[i]}` with type `{subst[v.id]}` " + f"type variable `{act.quantified[i]}` with type `{subst[v]}` " "containing free variables", node, ) - inst = [subst[v.id] for v in free_vars] + inst = [subst[v] for v in free_vars] subst = {v: t for v, t in subst.items() if v in exp.free_vars} return subst, inst @@ -407,12 +407,10 @@ def type_check_args( # If the argument check succeeded, this means that we must have found instantiations # for all unification variables occurring in the argument types - assert all( - set.issubset(set(arg.free_vars.keys()), subst.keys()) for arg in func_ty.args - ) + assert all(set.issubset(arg.free_vars, subst.keys()) for arg in func_ty.args) # We also have to check that we found instantiations for all vars in the return type - if not set.issubset(set(func_ty.returns.free_vars.keys()), subst.keys()): + if not set.issubset(func_ty.returns.free_vars, subst.keys()): raise GuppyTypeInferenceError( f"Cannot infer type variable in expression of type " f"`{func_ty.returns.substitute(subst)}`", @@ -440,7 +438,7 @@ def synthesize_call( # Success implies that the substitution is closed assert all(not t.free_vars for t in subst.values()) - inst = [subst[v.id] for v in free_vars] + inst = [subst[v] for v in free_vars] return args, unquantified.returns.substitute(subst), inst @@ -507,7 +505,7 @@ def check_call( # Also make sure we found an instantiation for all free vars in the type we're # checking against - if not set.issubset(set(ty.free_vars.keys()), subst.keys()): + if not set.issubset(ty.free_vars, subst.keys()): raise GuppyTypeInferenceError( f"Expected expression of type `{ty}`, got " f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", @@ -516,7 +514,7 @@ def check_call( # Success implies that the substitution is closed assert all(not t.free_vars for t in subst.values()) - inst = [subst[v.id] for v in free_vars] + inst = [subst[v] for v in free_vars] subst = {v: t for v, t in subst.items() if v in ty.free_vars} return args, subst, inst diff --git a/guppy/gtypes.py b/guppy/gtypes.py index fc467967..431f7805 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -9,7 +9,7 @@ Mapping, Iterator, ClassVar, - Literal, + Literal, Set, ) import guppy.hugr.tys as tys @@ -19,20 +19,7 @@ from guppy.checker.core import Globals -@dataclass(frozen=True) -class TypeVarId: - """Identifier for free type variables.""" - - id: int - - _id_generator: ClassVar[Iterator[int]] = itertools.count() - - @classmethod - def new(cls) -> "TypeVarId": - return TypeVarId(next(cls._id_generator)) - - -Subst = dict[TypeVarId, "GuppyType"] +Subst = dict["FreeTypeVar", "GuppyType"] Inst = Sequence["GuppyType"] @@ -46,7 +33,7 @@ class GuppyType(ABC): name: ClassVar[str] # Cache for free variables - _free_vars: Mapping["TypeVarId", "FreeTypeVar"] = field(init=False, repr=False) + _free_vars: Set["FreeTypeVar"] = field(init=False, repr=False) def __post_init__(self) -> None: # Make sure that we don't have higher-rank polymorphic types @@ -60,9 +47,9 @@ def __post_init__(self) -> None: # Compute free variables if isinstance(self, FreeTypeVar): - vs = {self.id: self} + vs = {self} else: - vs = {} + vs = set() for arg in self.type_args: vs |= arg.free_vars object.__setattr__(self, "_free_vars", vs) @@ -91,7 +78,7 @@ def transform(self, transformer: "TypeTransformer") -> "GuppyType": pass @property - def free_vars(self) -> Mapping["TypeVarId", "FreeTypeVar"]: + def free_vars(self) -> Set["FreeTypeVar"]: return self._free_vars def substitute(self, s: Subst) -> "GuppyType": @@ -128,14 +115,23 @@ def to_hugr(self) -> tys.SimpleType: @dataclass(frozen=True) class FreeTypeVar(GuppyType): - """Free type variable, identified with a globally unique id.""" + """Free type variable, identified with a globally unique id. - id: TypeVarId + Serves as an existential variable for unification. + """ + + id: int display_name: str linear: bool = False name: ClassVar[Literal["FreeTypeVar"]] = "FreeTypeVar" + _id_generator: ClassVar[Iterator[int]] = itertools.count() + + @classmethod + def new(cls, display_name: str, linear: bool) -> "FreeTypeVar": + return FreeTypeVar(next(cls._id_generator), display_name, linear) + @staticmethod def build(*rgs: GuppyType, node: Optional[AstNode] = None) -> GuppyType: raise NotImplementedError() @@ -150,6 +146,9 @@ def transform(self, transformer: "TypeTransformer") -> GuppyType: def __str__(self) -> str: return "?" + self.display_name + def __hash__(self) -> int: + return self.id + def to_hugr(self) -> tys.SimpleType: from guppy.error import InternalGuppyError @@ -234,10 +233,7 @@ def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": def unquantified(self) -> tuple["FunctionType", Sequence[FreeTypeVar]]: """Replaces all quantified variables with free type variables.""" - inst = [ - FreeTypeVar(TypeVarId.new(), v.display_name, v.linear) - for v in self.quantified - ] + inst = [FreeTypeVar.new(v.display_name, v.linear) for v in self.quantified] return self.instantiate(inst), inst @@ -402,7 +398,7 @@ def __init__(self, subst: Subst) -> None: def transform(self, ty: GuppyType) -> Optional[GuppyType]: if isinstance(ty, FreeTypeVar): - return self.subst.get(ty.id, None) + return self.subst.get(ty, None) return None @@ -450,13 +446,13 @@ def unify(s: GuppyType, t: GuppyType, subst: Optional[Subst]) -> Optional[Subst] def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Optional[Subst]: """Helper function for unification of type variables.""" - if var.id in subst: - return unify(subst[var.id], t, subst) - if isinstance(t, FreeTypeVar) and t.id in subst: - return unify(var, subst[t.id], subst) - if var.id in t.free_vars: + if var in subst: + return unify(subst[var], t, subst) + if isinstance(t, FreeTypeVar) and t in subst: + return unify(var, subst[t], subst) + if var in t.free_vars: return None - return {var.id: t, **subst} + return {var: t, **subst} def type_from_ast( From b52230bcf1e5527545dab90d4b7b506101e5db7f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 13:31:09 +0000 Subject: [PATCH 52/77] Respect linearity --- guppy/checker/expr_checker.py | 27 ++++++++++++ guppy/gtypes.py | 4 +- tests/error/poly_errors/non_linear1.err | 7 ++++ tests/error/poly_errors/non_linear1.py | 24 +++++++++++ tests/error/poly_errors/non_linear2.err | 7 ++++ tests/error/poly_errors/non_linear2.py | 26 ++++++++++++ tests/integration/test_poly.py | 55 +++++++++++++++++++++++++ 7 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 tests/error/poly_errors/non_linear1.err create mode 100644 tests/error/poly_errors/non_linear1.py create mode 100644 tests/error/poly_errors/non_linear2.err create mode 100644 tests/error/poly_errors/non_linear2.py diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index efc6cc83..d5985bbc 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -133,6 +133,7 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: # Otherwise, it must be a function as a higher-order value if isinstance(func_ty, FunctionType): args, return_ty, inst = check_call(func_ty, node.args, ty, node, self.ctx) + check_inst(func_ty, inst, node) node.func = instantiate_poly(node.func, func_ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(func_ty, "__call__"): @@ -360,6 +361,10 @@ def check_type_against( ) inst = [subst[v] for v in free_vars] subst = {v: t for v, t in subst.items() if v in exp.free_vars} + + # Finally, check that the instantiation respects the linearity requirements + check_inst(act, inst, node) + return subst, inst # Otherwise, we know that `act` has no free type vars, so unification is trivial @@ -439,6 +444,10 @@ def synthesize_call( # Success implies that the substitution is closed assert all(not t.free_vars for t in subst.values()) inst = [subst[v] for v in free_vars] + + # Finally, check that the instantiation respects the linearity requirements + check_inst(func_ty, inst, node) + return args, unquantified.returns.substitute(subst), inst @@ -516,9 +525,27 @@ def check_call( assert all(not t.free_vars for t in subst.values()) inst = [subst[v] for v in free_vars] subst = {v: t for v, t in subst.items() if v in ty.free_vars} + + # Finally, check that the instantiation respects the linearity requirements + check_inst(func_ty, inst, node) + return args, subst, inst +def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None: + """Checks if an instantiation is valid. + + Makes sure that the linearity requirements are satisfied. + """ + for var, ty in zip(func_ty.quantified, inst): + if not var.linear and ty.linear: + raise GuppyTypeError( + f"Cannot instantiate non-linear type variable `{var}` in type " + f"`{func_ty}` with linear type `{ty}`", + node, + ) + + def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: """Instantiates quantified type arguments in a function.""" assert len(ty.quantified) == len(inst) diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 431f7805..04f84623 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -6,10 +6,10 @@ Optional, Sequence, TYPE_CHECKING, - Mapping, Iterator, ClassVar, - Literal, Set, + Literal, + Set, ) import guppy.hugr.tys as tys diff --git a/tests/error/poly_errors/non_linear1.err b/tests/error/poly_errors/non_linear1.err new file mode 100644 index 00000000..a6a2ad1d --- /dev/null +++ b/tests/error/poly_errors/non_linear1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:21 + +19: @guppy(module) +20: def main(q: Qubit) -> None: +21: foo(q) + ^^^^^^ +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. T -> None` with linear type `Qubit` diff --git a/tests/error/poly_errors/non_linear1.py b/tests/error/poly_errors/non_linear1.py new file mode 100644 index 00000000..728dccd7 --- /dev/null +++ b/tests/error/poly_errors/non_linear1.py @@ -0,0 +1,24 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import Qubit + +import guppy.prelude.quantum as quantum + +module = GuppyModule("test") +module.load(quantum) + + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T) -> None: + ... + + +@guppy(module) +def main(q: Qubit) -> None: + foo(q) + + +module.compile() diff --git a/tests/error/poly_errors/non_linear2.err b/tests/error/poly_errors/non_linear2.err new file mode 100644 index 00000000..63c18f16 --- /dev/null +++ b/tests/error/poly_errors/non_linear2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:23 + +21: @guppy(module) +22: def main() -> None: +23: foo(h) + ^^^^^^ +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. T -> T -> None` with linear type `Qubit` diff --git a/tests/error/poly_errors/non_linear2.py b/tests/error/poly_errors/non_linear2.py new file mode 100644 index 00000000..b46238c3 --- /dev/null +++ b/tests/error/poly_errors/non_linear2.py @@ -0,0 +1,26 @@ +from typing import Callable + +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import h + +import guppy.prelude.quantum as quantum + +module = GuppyModule("test") +module.load(quantum) + + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: Callable[[T], T]) -> None: + ... + + +@guppy(module) +def main() -> None: + foo(h) + + +module.compile() diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 29728279..cc0f63f1 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -2,6 +2,9 @@ from guppy.decorator import guppy from guppy.module import GuppyModule +from guppy.prelude.quantum import Qubit + +import guppy.prelude.quantum as quantum def test_id(validate): @@ -208,3 +211,55 @@ def main() -> None: validate(module.compile()) + +def test_linear(validate): + module = GuppyModule("test") + module.load(quantum) + T = guppy.type_var(module, "T", linear=True) + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(q: Qubit) -> Qubit: + return foo(q) + + validate(module.compile()) + + +def test_pass_nonlinear(validate): + module = GuppyModule("test") + module.load(quantum) + T = guppy.type_var(module, "T", linear=True) + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> None: + foo(x) + + validate(module.compile()) + + +def test_pass_linear(validate): + module = GuppyModule("test") + module.load(quantum) + T = guppy.type_var(module, "T", linear=True) + + @guppy.declare(module) + def foo(f: Callable[[T], T]) -> None: + ... + + @guppy.declare(module) + def bar(q: Qubit) -> Qubit: + ... + + @guppy(module) + def main() -> None: + foo(bar) + + validate(module.compile()) + From 506398f7355ba31af0fd5a4195d7f7fe43f73690 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 13:34:52 +0000 Subject: [PATCH 53/77] Add more tests --- tests/error/poly_errors/define.err | 6 ++++++ tests/error/poly_errors/define.py | 15 +++++++++++++++ tests/integration/test_poly.py | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+) create mode 100644 tests/error/poly_errors/define.err create mode 100644 tests/error/poly_errors/define.py diff --git a/tests/error/poly_errors/define.err b/tests/error/poly_errors/define.err new file mode 100644 index 00000000..a38e2c1a --- /dev/null +++ b/tests/error/poly_errors/define.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:11 + +9: @guppy(module) +10: def main(x: T) -> T: + ^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Generic function definitions are not supported yet diff --git a/tests/error/poly_errors/define.py b/tests/error/poly_errors/define.py new file mode 100644 index 00000000..433c71d9 --- /dev/null +++ b/tests/error/poly_errors/define.py @@ -0,0 +1,15 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy(module) +def main(x: T) -> T: + return x + + +module.compile() diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index cc0f63f1..f7ea085a 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -263,3 +263,22 @@ def main() -> None: validate(module.compile()) + +def test_higher_order_value(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy.declare(module) + def bar(x: T) -> T: + ... + + @guppy(module) + def main(b: bool) -> int: + f = foo if b else bar + return f(42) + + validate(module.compile()) From 61e2cba8404164682f663aeaccca8d2deaa15965 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 Nov 2023 13:39:55 +0000 Subject: [PATCH 54/77] Detect unused linear expressions --- guppy/checker/stmt_checker.py | 4 +++- tests/error/linear_errors/unused_expr.err | 7 +++++++ tests/error/linear_errors/unused_expr.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 tests/error/linear_errors/unused_expr.err create mode 100644 tests/error/linear_errors/unused_expr.py diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 5b12dc58..d8321ff7 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -92,7 +92,9 @@ def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: def visit_Expr(self, node: ast.Expr) -> ast.stmt: # An expression statement where the return value is discarded - node.value, _ = self._synth_expr(node.value) + node.value, ty = self._synth_expr(node.value) + if ty.linear: + raise GuppyTypeError(f"Value with linear type `{ty}` is not used", node) return node def visit_Return(self, node: ast.Return) -> ast.stmt: diff --git a/tests/error/linear_errors/unused_expr.err b/tests/error/linear_errors/unused_expr.err new file mode 100644 index 00000000..6a0ae95e --- /dev/null +++ b/tests/error/linear_errors/unused_expr.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(q: Qubit) -> None: +13: h(q) + ^^^^ +GuppyTypeError: Value with linear type `Qubit` is not used diff --git a/tests/error/linear_errors/unused_expr.py b/tests/error/linear_errors/unused_expr.py new file mode 100644 index 00000000..6ad60a82 --- /dev/null +++ b/tests/error/linear_errors/unused_expr.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(q: Qubit) -> None: + h(q) + + +module.compile() From aface01384b4e88697ce17ad1e26567351b229ac Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 Nov 2023 09:17:39 +0000 Subject: [PATCH 55/77] Fix UnknownFunctionType --- guppy/error.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/guppy/error.py b/guppy/error.py index e8b56bde..62dcf4ff 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -3,10 +3,10 @@ import sys import textwrap from dataclasses import dataclass, field -from typing import Optional, Any, Sequence, Callable, TypeVar, cast +from typing import Optional, Any, Sequence, Callable, TypeVar, cast, Set from guppy.ast_util import AstNode, get_line_offset, get_file, get_source -from guppy.gtypes import GuppyType, FunctionType +from guppy.gtypes import GuppyType, FunctionType, BoundTypeVar, FreeTypeVar from guppy.hugr.hugr import OutPortV, Node @@ -128,6 +128,15 @@ def returns(self) -> GuppyType: def args_names(self) -> Optional[Sequence[str]]: raise InternalGuppyError("Tried to access unknown function type") + @property + def quantified(self) -> Sequence[BoundTypeVar]: + raise InternalGuppyError("Tried to access unknown function type") + + @property + def free_vars(self) -> Set[FreeTypeVar]: + return set() + + def format_source_location( loc: ast.AST, From 7269ea322748bf058ac6009ad5aa8580d7f6cc6a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 Nov 2023 09:20:51 +0000 Subject: [PATCH 56/77] Run formatting --- guppy/error.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guppy/error.py b/guppy/error.py index 62dcf4ff..d80a025e 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -137,7 +137,6 @@ def free_vars(self) -> Set[FreeTypeVar]: return set() - def format_source_location( loc: ast.AST, num_lines: int = 3, From 302e2aa7d50bc226079552b6f529c2313a3da3f7 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 Nov 2023 10:55:25 +0000 Subject: [PATCH 57/77] Improve synthesize_call docstring --- guppy/checker/expr_checker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 747d32c0..7d8eefac 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -277,7 +277,10 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[list[ast.expr], GuppyType]: - """Synthesizes the return type of a function call""" + """Synthesizes the return type of a function call. + + Also returns desugared versions of the arguments with type annotations. + """ check_num_args(len(func_ty.args), len(args), node) for i, arg in enumerate(args): args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") From 145c6aaa7414c84d72614b13e54c6eb655f69a91 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 Nov 2023 11:27:52 +0000 Subject: [PATCH 58/77] Add module docstrings --- guppy/checker/cfg_checker.py | 6 ++++++ guppy/checker/expr_checker.py | 22 ++++++++++++++++++++++ guppy/checker/func_checker.py | 7 +++++++ guppy/checker/stmt_checker.py | 10 ++++++++++ 4 files changed, 45 insertions(+) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 1fbb0521..85eaeb9a 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -1,3 +1,9 @@ +"""Type checking code for control-flow graphs + +Operates on CFGs produced by the `CFGBuilder`. Produces a `CheckedCFG` consisting of +`CheckedBB`s with inferred type signatures. +""" + import collections from dataclasses import dataclass from typing import Sequence diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 7d8eefac..0d1f8918 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -1,3 +1,25 @@ +"""Type checking and synthesizing code for expressions. + +Operates on expressions in a basic block after CFG construction. In particular, we +assume that expressions that involve control flow (i.e. short-circuiting and ternary +expressions) have been removed during CFG construction. + +Furthermore, we assume that assignment expressions with the walrus operator := have +been turned into regular assignments and are no longer present. As a result, expressions +are assumed to be side effect free, in the sense that they do not modify the variables +available in the type checking context. + +We may alter/desugar AST nodes during type checking. In particular, we turn `ast.Name` +nodes into either `LocalName` or `GlobalName` nodes and `ast.Call` nodes are turned into +`LocalCall` or `GlobalCall` nodes. Furthermore, all nodes in the resulting AST are +annotated with their type. + +Expressions can be checked against a given type by the `ExprChecker`, raising an Error +if the expressions doesn't have the expected type. Checking is used for annotated +assignments, return values, and function arguments. Alternatively, the `ExprSynthesizer` +can be used to infer a type for an expression. +""" + import ast from typing import Optional, Union, NoReturn, Any diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 87b22b8f..d7123028 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -1,3 +1,10 @@ +"""Type checking code for top-level and nested function definitions. + +For top-level functions, we take the `DefinedFunction` containing the `ast.FunctionDef` +node straight from the Python AST. We build a CFG, check it, and return a +`CheckedFunction` containing a `CheckedCFG` with type annotations. +""" + import ast from dataclasses import dataclass diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index d8321ff7..8f35f427 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -1,3 +1,13 @@ +"""Type checking code for statements. + +Operates on statements in a basic block after CFG construction. In particular, we +assume that statements involving control flow (i.e. if, while, break, and return +statements) have been removed during CFG construction. + +After checking, we return a desugared statement where all sub-expression have beem type +annotated. +""" + import ast from typing import Sequence From 6b833cbf3a861fd107514c992635288f861aca80 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 Nov 2023 11:31:25 +0000 Subject: [PATCH 59/77] Rename cfg to containing_cfg --- guppy/cfg/bb.py | 2 +- guppy/checker/cfg_checker.py | 4 ++-- guppy/checker/func_checker.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index f5831dbf..95c9f63a 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -46,7 +46,7 @@ class BB(ABC): idx: int # Pointer to the CFG that contains this node - cfg: "BaseCFG[Self]" + containing_cfg: "BaseCFG[Self]" # AST statements contained in this BB statements: list[BBStatement] = field(default_factory=list) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 85eaeb9a..67868bec 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -119,7 +119,7 @@ def check_bb( return_ty: GuppyType, globals: Globals, ) -> CheckedBB: - cfg = bb.cfg + cfg = bb.containing_cfg # For the entry BB we have to separately check that all used variables are # defined. For all other BBs, this will be checked when compiling a predecessor. @@ -217,6 +217,6 @@ def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: raise GuppyError( f"{ident} can refer to different types: " f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", - bb.cfg.live_before[bb][v1.name].vars.used[v1.name], + bb.containing_cfg.live_before[bb][v1.name].vars.used[v1.name], [v1.defined_at, v2.defined_at], ) diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index d7123028..b2de26f0 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -83,7 +83,7 @@ def check_nested_func_def( cfg = func_def.cfg # Find captured variables - parent_cfg = bb.cfg + parent_cfg = bb.containing_cfg def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] cfg.analyze(def_ass_before, maybe_ass_before) From 29c6009b8a38b8c0dd6d6ccdbd039ca3bec18507 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 Nov 2023 11:33:29 +0000 Subject: [PATCH 60/77] Improve check for entry BB --- guppy/checker/cfg_checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 67868bec..77f2fa51 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -123,7 +123,8 @@ def check_bb( # For the entry BB we have to separately check that all used variables are # defined. For all other BBs, this will be checked when compiling a predecessor. - if len(bb.predecessors) == 0: + if bb == cfg.entry_bb: + assert len(bb.predecessors) == 0 for x, use in bb.vars.used.items(): if x not in cfg.ass_before[bb] and x not in globals.values: raise GuppyError(f"Variable `{x}` is not defined", use) From f982901fb0570ef8ed6dadb6814eeafcebc4937b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:27:39 +0000 Subject: [PATCH 61/77] Use booly value of list --- guppy/checker/func_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index b2de26f0..00372da9 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -126,7 +126,7 @@ def check_nested_func_def( # Check if the body contains a recursive occurrence of the function name if func_def.name in cfg.live_before[cfg.entry_bb]: - if len(captured) == 0: + if not captured: # If there are no captured vars, we treat the function like a global name func = DefinedFunction(func_def.name, func_ty, func_def, None) globals = ctx.globals | Globals({func_def.name: func}, {}) From 55d72715951f15dfc17eba0e53c57f24ea20ccfe Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:28:01 +0000 Subject: [PATCH 62/77] Use contextlib suppress --- guppy/checker/expr_checker.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 0d1f8918..eeb6b2f9 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -21,6 +21,7 @@ """ import ast +from contextlib import suppress from typing import Optional, Union, NoReturn, Any from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt @@ -212,16 +213,12 @@ def _synthesize_binary( right_expr, right_ty = self.synthesize(right_expr) if func := self.ctx.globals.get_instance_func(left_ty, lop): - try: + with suppress(GuppyError): return func.synthesize_call([left_expr, right_expr], node, self.ctx) - except GuppyError: - pass if func := self.ctx.globals.get_instance_func(right_ty, rop): - try: + with suppress(GuppyError): return func.synthesize_call([right_expr, left_expr], node, self.ctx) - except GuppyError: - pass raise GuppyTypeError( f"Binary operator `{display_name}` not defined for arguments of type " From 6e3c2f340d3a7f7fc50052df4b40f3da40564f5a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:28:55 +0000 Subject: [PATCH 63/77] Fix typo --- guppy/checker/stmt_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 8f35f427..a5503155 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -4,7 +4,7 @@ assume that statements involving control flow (i.e. if, while, break, and return statements) have been removed during CFG construction. -After checking, we return a desugared statement where all sub-expression have beem type +After checking, we return a desugared statement where all sub-expression have been type annotated. """ From dacc31f1b366f6f5f2a9305408160409c246828c Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:31:11 +0000 Subject: [PATCH 64/77] Clarify comment --- guppy/checker/func_checker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 00372da9..e0695824 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -124,7 +124,9 @@ def check_nested_func_def( ] globals = ctx.globals - # Check if the body contains a recursive occurrence of the function name + # Check if the body contains a free (recursive) occurrence of the function name. + # By checking if the name is free at the entry BB, we avoid false positives when + # a user shadows the name with a local variable if func_def.name in cfg.live_before[cfg.entry_bb]: if not captured: # If there are no captured vars, we treat the function like a global name From a872eed98697003f117054bcc75265eee1210a69 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:35:43 +0000 Subject: [PATCH 65/77] Clarify docstring --- guppy/checker/expr_checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index eeb6b2f9..3d37f790 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -14,8 +14,8 @@ `LocalCall` or `GlobalCall` nodes. Furthermore, all nodes in the resulting AST are annotated with their type. -Expressions can be checked against a given type by the `ExprChecker`, raising an Error -if the expressions doesn't have the expected type. Checking is used for annotated +Expressions can be checked against a given type by the `ExprChecker`, raising a type +error if the expressions doesn't have the expected type. Checking is used for annotated assignments, return values, and function arguments. Alternatively, the `ExprSynthesizer` can be used to infer a type for an expression. """ From 43e25e39e0aab93323a6e0e99bb5f7873f05858a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:43:08 +0000 Subject: [PATCH 66/77] Align op table --- guppy/checker/expr_checker.py | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 3d37f790..768e6b25 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -32,34 +32,34 @@ # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { - ast.UAdd: ("__pos__", "+"), - ast.USub: ("__neg__", "-"), + ast.UAdd: ("__pos__", "+"), + ast.USub: ("__neg__", "-"), ast.Invert: ("__invert__", "~"), -} +} # fmt: skip # Mapping from binary AST op to left dunder method, right dunder method and display name AstOp = Union[ast.operator, ast.cmpop] binary_table: dict[type[AstOp], tuple[str, str, str]] = { - ast.Add: ("__add__", "__radd__", "+"), - ast.Sub: ("__sub__", "__rsub__", "-"), - ast.Mult: ("__mul__", "__rmul__", "*"), - ast.Div: ("__truediv__", "__rtruediv__", "/"), + ast.Add: ("__add__", "__radd__", "+"), + ast.Sub: ("__sub__", "__rsub__", "-"), + ast.Mult: ("__mul__", "__rmul__", "*"), + ast.Div: ("__truediv__", "__rtruediv__", "/"), ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), - ast.Mod: ("__mod__", "__rmod__", "%"), - ast.Pow: ("__pow__", "__rpow__", "**"), - ast.LShift: ("__lshift__", "__rlshift__", "<<"), - ast.RShift: ("__rshift__", "__rrshift__", ">>"), - ast.BitOr: ("__or__", "__ror__", "|"), - ast.BitXor: ("__xor__", "__rxor__", "^"), - ast.BitAnd: ("__and__", "__rand__", "&"), - ast.MatMult: ("__matmul__", "__rmatmul__", "@"), - ast.Eq: ("__eq__", "__eq__", "=="), - ast.NotEq: ("__neq__", "__neq__", "!="), - ast.Lt: ("__lt__", "__gt__", "<"), - ast.LtE: ("__le__", "__ge__", "<="), - ast.Gt: ("__gt__", "__lt__", ">"), - ast.GtE: ("__ge__", "__le__", ">="), -} + ast.Mod: ("__mod__", "__rmod__", "%"), + ast.Pow: ("__pow__", "__rpow__", "**"), + ast.LShift: ("__lshift__", "__rlshift__", "<<"), + ast.RShift: ("__rshift__", "__rrshift__", ">>"), + ast.BitOr: ("__or__", "__ror__", "|"), + ast.BitXor: ("__xor__", "__rxor__", "^"), + ast.BitAnd: ("__and__", "__rand__", "&"), + ast.MatMult: ("__matmul__", "__rmatmul__", "@"), + ast.Eq: ("__eq__", "__eq__", "=="), + ast.NotEq: ("__neq__", "__neq__", "!="), + ast.Lt: ("__lt__", "__gt__", "<"), + ast.LtE: ("__le__", "__ge__", "<="), + ast.Gt: ("__gt__", "__lt__", ">"), + ast.GtE: ("__ge__", "__le__", ">="), +} # fmt: skip class ExprChecker(AstVisitor[ast.expr]): From 4de95e8c7e0ed73fd37921fa846956a08a93af9b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 09:45:33 +0000 Subject: [PATCH 67/77] Specify type: ignore --- guppy/checker/expr_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 768e6b25..01662a1b 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -114,7 +114,7 @@ def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: node.elts[i] = self.check(el, ty.element_types[i]) return node - def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore + def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore[override] # Try to synthesize and then check if it matches the given type node, synth = self._synthesize(node) if synth != ty: From 34db487fff4ac2960a15483fa1d8aa1a626475d1 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 10:08:01 +0000 Subject: [PATCH 68/77] Improve error message wording --- guppy/checker/expr_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 01662a1b..4e27d943 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -241,7 +241,7 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: if len(node.keywords) > 0: raise GuppyError( - "Argument passing by keyword is not supported", node.keywords[0] + "Keyword arguments are not supported", node.keywords[0] ) node.func, ty = self.synthesize(node.func) From c522771db05359b56c3c90b262add503cb05de58 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 10:16:54 +0000 Subject: [PATCH 69/77] Use pattern matches --- .github/workflows/pull-request.yaml | 2 +- guppy/checker/expr_checker.py | 16 ++++---- guppy/checker/stmt_checker.py | 58 +++++++++++++++-------------- 3 files changed, 40 insertions(+), 36 deletions(-) diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 52f55df8..7638abfc 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: - python-version: [3.9] + python-version: [3.10] steps: - uses: actions/checkout@v3 diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 4e27d943..1727a3cd 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -354,10 +354,12 @@ def python_value_to_guppy_type( Returns `None` if the Python value cannot be represented in Guppy. """ - if isinstance(v, bool): - return globals.types["bool"].build(node=node) - elif isinstance(v, int): - return globals.types["int"].build(node=node) - if isinstance(v, float): - return globals.types["float"].build(node=node) - return None + match v: + case bool(): + return globals.types["bool"].build(node=node) + case int(): + return globals.types["int"].build(node=node) + case float(): + return globals.types["float"].build(node=node) + case _: + return None diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index a5503155..5622ee55 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -43,35 +43,37 @@ def _check_expr( def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: """Helper function to check assignments with patterns.""" - # Easiest case is if the LHS pattern is a single variable. - if isinstance(lhs, ast.Name): - # Check if we override an unused linear variable - x = lhs.id - if x in self.ctx.locals: - var = self.ctx.locals[x] - if var.ty.linear and var.used is None: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is not used", - var.defined_at, + match lhs: + # Easiest case is if the LHS pattern is a single variable. + case ast.Name(id=x): + # Check if we override an unused linear variable + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + self.ctx.locals[x] = Variable(x, ty, node, None) + + # The only other thing we support right now are tuples + case ast.Tuple(elts=elts): + tys = ty.element_types if isinstance(ty, TupleType) else [ty] + n, m = len(elts), len(tys) + if n != m: + raise GuppyTypeError( + f"{'Too many' if n < m else 'Not enough'} values to unpack " + f"(expected {n}, got {m})", + node, ) - self.ctx.locals[x] = Variable(x, ty, node, None) - # The only other thing we support right now are tuples - elif isinstance(lhs, ast.Tuple): - tys = ty.element_types if isinstance(ty, TupleType) else [ty] - n, m = len(lhs.elts), len(tys) - if n != m: - raise GuppyTypeError( - f"{'Too many' if n < m else 'Not enough'} values to unpack " - f"(expected {n}, got {m})", - node, - ) - for pat, el_ty in zip(lhs.elts, tys): - self._check_assign(pat, el_ty, node) - # TODO: Python also supports assignments like `[a, b] = [1, 2]` or - # `a, *b = ...`. The former would require some runtime checks but - # the latter should be easier to do (unpack and repack the rest). - else: - raise GuppyError("Assignment pattern not supported", lhs) + for pat, el_ty in zip(elts, tys): + self._check_assign(pat, el_ty, node) + + # TODO: Python also supports assignments like `[a, b] = [1, 2]` or + # `a, *b = ...`. The former would require some runtime checks but + # the latter should be easier to do (unpack and repack the rest). + case _: + raise GuppyError("Assignment pattern not supported", lhs) def visit_Assign(self, node: ast.Assign) -> ast.stmt: if len(node.targets) > 1: From f930eccf61b7d13cb4327739c18f5ba716a26c82 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 10:36:52 +0000 Subject: [PATCH 70/77] Turn assert into exception --- guppy/checker/expr_checker.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 1727a3cd..09c12774 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -85,7 +85,8 @@ def _fail( if not isinstance(actual, GuppyType): loc = loc or actual _, actual = self._synthesize(actual) - assert loc is not None + if loc is None: + raise InternalGuppyError("Failure location is required") raise GuppyTypeError( f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc ) @@ -240,9 +241,7 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: if len(node.keywords) > 0: - raise GuppyError( - "Keyword arguments are not supported", node.keywords[0] - ) + raise GuppyError("Keyword arguments are not supported", node.keywords[0]) node.func, ty = self.synthesize(node.func) # First handle direct calls of user-defined functions and extension functions From f5e60cff8255b264d99d972659e8e9e56a6c4eae Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 10:41:05 +0000 Subject: [PATCH 71/77] Specify type: ignore --- guppy/checker/cfg_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 77f2fa51..787b2ccb 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -188,7 +188,7 @@ def check_bb( checked_bb = CheckedBB( bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) ) - checked_bb.successors = [None] * len(bb.successors) # type: ignore + checked_bb.successors = [None] * len(bb.successors) # type: ignore[list-item] checked_bb.branch_pred = bb.branch_pred return checked_bb From 2f29f2f1e1cccce43383f2e7781e3d48f6da2fca Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 4 Dec 2023 10:54:41 +0000 Subject: [PATCH 72/77] Fix CI --- .github/workflows/pull-request.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 7638abfc..4ae52902 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: - python-version: [3.10] + python-version: ['3.10'] steps: - uses: actions/checkout@v3 From 325dd250f3ead8c5e8e5177e686da5dbd19496f8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 11 Dec 2023 09:45:57 +0000 Subject: [PATCH 73/77] Fix linting --- guppy/checker/stmt_checker.py | 3 ++- guppy/compiler/expr_compiler.py | 2 +- guppy/gtypes.py | 21 ++++++++++----------- guppy/hugr/hugr.py | 2 +- tests/integration/test_poly.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 0d712f0c..3959801e 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -93,7 +93,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: ) ty = type_from_ast(node.annotation, self.ctx.globals) node.value, subst = self._check_expr(node.value, ty) - assert not ty.free_vars and len(subst) == 0 # `ty` must be closed! + assert not ty.free_vars # `ty` must be closed! + assert len(subst) == 0 self._check_assign(node.target, ty, node) return node diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 05b48c65..97b08945 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -134,7 +134,7 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: """Checks if instantiating a polymorphic makes it return a row.""" if isinstance(func_ty.returns, BoundTypeVar): return_ty = inst[func_ty.returns.idx] - return isinstance(return_ty, TupleType) or isinstance(return_ty, NoneType) + return isinstance(return_ty, TupleType | NoneType) return False diff --git a/guppy/gtypes.py b/guppy/gtypes.py index e04984d3..344c8855 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -7,7 +7,6 @@ TYPE_CHECKING, ClassVar, Literal, - Optional, ) import guppy.hugr.tys as tys @@ -94,7 +93,7 @@ class BoundTypeVar(GuppyType): name: ClassVar[Literal["BoundTypeVar"]] = "BoundTypeVar" @staticmethod - def build(*rgs: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: raise NotImplementedError @property @@ -131,7 +130,7 @@ def new(cls, display_name: str, linear: bool) -> "FreeTypeVar": return FreeTypeVar(next(cls._id_generator), display_name, linear) @staticmethod - def build(*rgs: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: raise NotImplementedError @property @@ -378,7 +377,7 @@ class TypeTransformer(ABC): """Abstract base class for a type visitor that transforms types.""" @abstractmethod - def transform(self, ty: GuppyType) -> Optional[GuppyType]: + def transform(self, ty: GuppyType) -> GuppyType | None: """This method is called for each visited type. Return a transformed type or `None` to continue the recursive visit. @@ -393,7 +392,7 @@ class Substituter(TypeTransformer): def __init__(self, subst: Subst) -> None: self.subst = subst - def transform(self, ty: GuppyType) -> Optional[GuppyType]: + def transform(self, ty: GuppyType) -> GuppyType | None: if isinstance(ty, FreeTypeVar): return self.subst.get(ty, None) return None @@ -407,7 +406,7 @@ class Instantiator(TypeTransformer): def __init__(self, tys: Sequence[GuppyType]) -> None: self.tys = tys - def transform(self, ty: GuppyType) -> Optional[GuppyType]: + def transform(self, ty: GuppyType) -> GuppyType | None: if isinstance(ty, BoundTypeVar): # Instantiate if type for the index is available if ty.idx < len(self.tys): @@ -418,7 +417,7 @@ def transform(self, ty: GuppyType) -> Optional[GuppyType]: return None -def unify(s: GuppyType, t: GuppyType, subst: Optional[Subst]) -> Optional[Subst]: +def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: """Computes a most general unifier for two types. Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this @@ -441,7 +440,7 @@ def unify(s: GuppyType, t: GuppyType, subst: Optional[Subst]) -> Optional[Subst] return None -def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Optional[Subst]: +def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Subst | None: """Helper function for unification of type variables.""" if var in subst: return unify(subst[var], t, subst) @@ -455,7 +454,7 @@ def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Optional[Subst]: def type_from_ast( node: AstNode, globals: "Globals", - type_var_mapping: Optional[dict[str, BoundTypeVar]] = None, + type_var_mapping: dict[str, BoundTypeVar] | None = None, ) -> GuppyType: """Turns an AST expression into a Guppy type.""" from guppy.error import GuppyError @@ -484,8 +483,8 @@ def type_from_ast( if isinstance(v, str): try: return type_from_ast(ast.parse(v), globals, type_var_mapping) - except Exception: - raise GuppyError("Invalid Guppy type", node) + except SyntaxError: + raise GuppyError("Invalid Guppy type", node) from None raise GuppyError(f"Constant `{v}` is not a valid type", node) if isinstance(node, ast.Tuple): diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 8026822e..71846e6f 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -526,7 +526,7 @@ def add_partial( ) def add_type_apply( - self, func_port: OutPortV, args: Inst, parent: Optional[Node] = None + self, func_port: OutPortV, args: Inst, parent: Node | None = None ) -> VNode: """Adds a `TypeApply` node to the graph.""" assert isinstance(func_port.ty, FunctionType) diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index f7ea085a..2446e308 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from guppy.decorator import guppy from guppy.module import GuppyModule From f29ae3ceb68dd9ccbe8778556605d90c980ca069 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 3 Jan 2024 17:28:26 +0100 Subject: [PATCH 74/77] Improve ExprChecker docstring --- guppy/checker/expr_checker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index c54f5e68..d3f9756c 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -77,7 +77,11 @@ class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]): - """Checks an expression against a type and produces a new type-annotated AST""" + """Checks an expression against a type and produces a new type-annotated AST. + + The type may contain free variables that the checker will try to solve. Note that + the checker will fail, if some free variables cannot be inferred. + """ ctx: Context From 40bd94a31490ba9d2f926a8afa2d5ec1c5f1e4a2 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 3 Jan 2024 18:18:41 +0100 Subject: [PATCH 75/77] Fix error column offset --- guppy/error.py | 2 +- tests/error/misc_errors/return_not_annotated.err | 2 +- tests/error/misc_errors/return_not_annotated_none1.err | 2 +- tests/error/misc_errors/return_not_annotated_none2.err | 2 +- tests/error/poly_errors/define.err | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/guppy/error.py b/guppy/error.py index 530881cc..5c434901 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -144,7 +144,7 @@ def format_source_location( source_lines = source.splitlines(keepends=True) end_col_offset = loc.end_col_offset if end_col_offset is None or (loc.end_lineno and loc.end_lineno > loc.lineno): - end_col_offset = len(source_lines[loc.lineno - 1]) + end_col_offset = len(source_lines[loc.lineno - 1]) - 1 s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() s += "\n" + loc.col_offset * " " + (end_col_offset - loc.col_offset) * "^" s = textwrap.dedent(s).splitlines() diff --git a/tests/error/misc_errors/return_not_annotated.err b/tests/error/misc_errors/return_not_annotated.err index 879de215..dcce68ed 100644 --- a/tests/error/misc_errors/return_not_annotated.err +++ b/tests/error/misc_errors/return_not_annotated.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(x: bool): - ^^^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^^^^^^ GuppyError: Return type must be annotated diff --git a/tests/error/misc_errors/return_not_annotated_none1.err b/tests/error/misc_errors/return_not_annotated_none1.err index 82f22409..58d79a43 100644 --- a/tests/error/misc_errors/return_not_annotated_none1.err +++ b/tests/error/misc_errors/return_not_annotated_none1.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(): - ^^^^^^^^^^^ + ^^^^^^^^^^ GuppyError: Return type must be annotated. Try adding a `-> None` annotation. diff --git a/tests/error/misc_errors/return_not_annotated_none2.err b/tests/error/misc_errors/return_not_annotated_none2.err index 82f22409..58d79a43 100644 --- a/tests/error/misc_errors/return_not_annotated_none2.err +++ b/tests/error/misc_errors/return_not_annotated_none2.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(): - ^^^^^^^^^^^ + ^^^^^^^^^^ GuppyError: Return type must be annotated. Try adding a `-> None` annotation. diff --git a/tests/error/poly_errors/define.err b/tests/error/poly_errors/define.err index a38e2c1a..d2500379 100644 --- a/tests/error/poly_errors/define.err +++ b/tests/error/poly_errors/define.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:11 9: @guppy(module) 10: def main(x: T) -> T: - ^^^^^^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^^^^^^^^^ GuppyError: Generic function definitions are not supported yet From 876a5dec34f75f7321cdf9ddc6da240475fad061 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 9 Jan 2024 12:06:13 +0100 Subject: [PATCH 76/77] Rename FreeTypeVar to ExistentialTypeVar and free_vars to unsolved_vars --- guppy/checker/expr_checker.py | 32 ++++++++++++------------- guppy/checker/stmt_checker.py | 4 ++-- guppy/error.py | 4 ++-- guppy/gtypes.py | 44 ++++++++++++++++++----------------- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index d3f9756c..9e799399 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -34,7 +34,7 @@ ) from guppy.gtypes import ( BoolType, - FreeTypeVar, + ExistentialTypeVar, FunctionType, GuppyType, Inst, @@ -119,7 +119,7 @@ def check( resolved type variables. """ # When checking against a variable, we have to synthesize - if isinstance(ty, FreeTypeVar): + if isinstance(ty, ExistentialTypeVar): expr, syn_ty = self._synthesize(expr, allow_free_vars=False) return with_type(syn_ty, expr), {ty: syn_ty} @@ -199,7 +199,7 @@ def synthesize( if ty := get_type_opt(node): return node, ty node, ty = self.visit(node) - if ty.free_vars and not allow_free_vars: + if ty.unsolved_vars and not allow_free_vars: raise GuppyTypeError( f"Cannot infer type variable in expression of type `{ty}`", node ) @@ -355,7 +355,7 @@ def check_type_against( be quantified and the actual type may not contain free unification variables. """ assert not isinstance(exp, FunctionType) or not exp.quantified - assert not act.free_vars + assert not act.unsolved_vars # The actual type may be quantified. In that case, we have to find an instantiation # to avoid higher-rank types. @@ -374,7 +374,7 @@ def check_type_against( "rank polymorphic types are not supported)", node, ) - if subst[v].free_vars: + if subst[v].unsolved_vars: raise GuppyTypeError( f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " f"type variable `{act.quantified[i]}` with type `{subst[v]}` " @@ -382,15 +382,15 @@ def check_type_against( node, ) inst = [subst[v] for v in free_vars] - subst = {v: t for v, t in subst.items() if v in exp.free_vars} + subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements check_inst(act, inst, node) return subst, inst - # Otherwise, we know that `act` has no free type vars, so unification is trivial - assert not act.free_vars + # Otherwise, we know that `act` has no unsolved type vars, so unification is trivial + assert not act.unsolved_vars subst = unify(exp, act, {}) if subst is None: raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) @@ -434,10 +434,10 @@ def type_check_args( # If the argument check succeeded, this means that we must have found instantiations # for all unification variables occurring in the argument types - assert all(set.issubset(arg.free_vars, subst.keys()) for arg in func_ty.args) + assert all(set.issubset(arg.unsolved_vars, subst.keys()) for arg in func_ty.args) # We also have to check that we found instantiations for all vars in the return type - if not set.issubset(func_ty.returns.free_vars, subst.keys()): + if not set.issubset(func_ty.returns.unsolved_vars, subst.keys()): raise GuppyTypeInferenceError( f"Cannot infer type variable in expression of type " f"`{func_ty.returns.substitute(subst)}`", @@ -455,7 +455,7 @@ def synthesize_call( Returns an annotated argument list, the synthesized return type, and an instantiation for the quantifiers in the function type. """ - assert not func_ty.free_vars + assert not func_ty.unsolved_vars check_num_args(len(func_ty.args), len(args), node) # Replace quantified variables with free unification variables and try to infer an @@ -464,7 +464,7 @@ def synthesize_call( args, subst = type_check_args(args, unquantified, {}, ctx, node) # Success implies that the substitution is closed - assert all(not t.free_vars for t in subst.values()) + assert all(not t.unsolved_vars for t in subst.values()) inst = [subst[v] for v in free_vars] # Finally, check that the instantiation respects the linearity requirements @@ -486,7 +486,7 @@ def check_call( Returns an annotated argument list, a substitution for the free variables in the expected type, and an instantiation for the quantifiers in the function type. """ - assert not func_ty.free_vars + assert not func_ty.unsolved_vars check_num_args(len(func_ty.args), len(args), node) # When checking, we can use the information from the expected return type to infer @@ -536,7 +536,7 @@ def check_call( # Also make sure we found an instantiation for all free vars in the type we're # checking against - if not set.issubset(ty.free_vars, subst.keys()): + if not set.issubset(ty.unsolved_vars, subst.keys()): raise GuppyTypeInferenceError( f"Expected expression of type `{ty}`, got " f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", @@ -544,9 +544,9 @@ def check_call( ) # Success implies that the substitution is closed - assert all(not t.free_vars for t in subst.values()) + assert all(not t.unsolved_vars for t in subst.values()) inst = [subst[v] for v in free_vars] - subst = {v: t for v, t in subst.items() if v in ty.free_vars} + subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 3959801e..0fa46608 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -26,7 +26,7 @@ class StmtChecker(AstVisitor[BBStatement]): return_ty: GuppyType def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: - assert not return_ty.free_vars + assert not return_ty.unsolved_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -93,7 +93,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: ) ty = type_from_ast(node.annotation, self.ctx.globals) node.value, subst = self._check_expr(node.value, ty) - assert not ty.free_vars # `ty` must be closed! + assert not ty.unsolved_vars # `ty` must be closed! assert len(subst) == 0 self._check_assign(node.target, ty, node) return node diff --git a/guppy/error.py b/guppy/error.py index 5c434901..14bb5813 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -7,7 +7,7 @@ from typing import Any, TypeVar, cast from guppy.ast_util import AstNode, get_file, get_line_offset, get_source -from guppy.gtypes import BoundTypeVar, FreeTypeVar, FunctionType, GuppyType +from guppy.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType from guppy.hugr.hugr import Node, OutPortV # Whether the interpreter should exit when a Guppy error occurs @@ -128,7 +128,7 @@ def quantified(self) -> Sequence[BoundTypeVar]: raise InternalGuppyError("Tried to access unknown function type") @property - def free_vars(self) -> set[FreeTypeVar]: + def unsolved_vars(self) -> set[ExistentialTypeVar]: return set() diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 344c8855..9046e2cb 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -16,7 +16,7 @@ from guppy.checker.core import Globals -Subst = dict["FreeTypeVar", "GuppyType"] +Subst = dict["ExistentialTypeVar", "GuppyType"] Inst = Sequence["GuppyType"] @@ -30,7 +30,7 @@ class GuppyType(ABC): name: ClassVar[str] # Cache for free variables - _free_vars: set["FreeTypeVar"] = field(init=False, repr=False) + _unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False) def __post_init__(self) -> None: # Make sure that we don't have higher-rank polymorphic types @@ -43,13 +43,13 @@ def __post_init__(self) -> None: ) # Compute free variables - if isinstance(self, FreeTypeVar): + if isinstance(self, ExistentialTypeVar): vs = {self} else: vs = set() for arg in self.type_args: - vs |= arg.free_vars - object.__setattr__(self, "_free_vars", vs) + vs |= arg.unsolved_vars + object.__setattr__(self, "_unsolved_vars", vs) @staticmethod @abstractmethod @@ -75,8 +75,8 @@ def transform(self, transformer: "TypeTransformer") -> "GuppyType": pass @property - def free_vars(self) -> set["FreeTypeVar"]: - return self._free_vars + def unsolved_vars(self) -> set["ExistentialTypeVar"]: + return self._unsolved_vars def substitute(self, s: Subst) -> "GuppyType": return self.transform(Substituter(s)) @@ -111,23 +111,23 @@ def to_hugr(self) -> tys.SimpleType: @dataclass(frozen=True) -class FreeTypeVar(GuppyType): - """Free type variable, identified with a globally unique id. +class ExistentialTypeVar(GuppyType): + """Existential type variable, identified with a globally unique id. - Serves as an existential variable for unification. + Is solved during type checking. """ id: int display_name: str linear: bool = False - name: ClassVar[Literal["FreeTypeVar"]] = "FreeTypeVar" + name: ClassVar[Literal["ExistentialTypeVar"]] = "ExistentialTypeVar" _id_generator: ClassVar[Iterator[int]] = itertools.count() @classmethod - def new(cls, display_name: str, linear: bool) -> "FreeTypeVar": - return FreeTypeVar(next(cls._id_generator), display_name, linear) + def new(cls, display_name: str, linear: bool) -> "ExistentialTypeVar": + return ExistentialTypeVar(next(cls._id_generator), display_name, linear) @staticmethod def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: @@ -228,9 +228,11 @@ def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": self.arg_names, ) - def unquantified(self) -> tuple["FunctionType", Sequence[FreeTypeVar]]: + def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialTypeVar]]: """Replaces all quantified variables with free type variables.""" - inst = [FreeTypeVar.new(v.display_name, v.linear) for v in self.quantified] + inst = [ + ExistentialTypeVar.new(v.display_name, v.linear) for v in self.quantified + ] return self.instantiate(inst), inst @@ -393,7 +395,7 @@ def __init__(self, subst: Subst) -> None: self.subst = subst def transform(self, ty: GuppyType) -> GuppyType | None: - if isinstance(ty, FreeTypeVar): + if isinstance(ty, ExistentialTypeVar): return self.subst.get(ty, None) return None @@ -427,9 +429,9 @@ def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: return None if s == t: return subst - if isinstance(s, FreeTypeVar): + if isinstance(s, ExistentialTypeVar): return _unify_var(s, t, subst) - if isinstance(t, FreeTypeVar): + if isinstance(t, ExistentialTypeVar): return _unify_var(t, s, subst) if type(s) == type(t): sargs, targs = list(s.type_args), list(t.type_args) @@ -440,13 +442,13 @@ def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: return None -def _unify_var(var: FreeTypeVar, t: GuppyType, subst: Subst) -> Subst | None: +def _unify_var(var: ExistentialTypeVar, t: GuppyType, subst: Subst) -> Subst | None: """Helper function for unification of type variables.""" if var in subst: return unify(subst[var], t, subst) - if isinstance(t, FreeTypeVar) and t in subst: + if isinstance(t, ExistentialTypeVar) and t in subst: return unify(var, subst[t], subst) - if var in t.free_vars: + if var in t.unsolved_vars: return None return {var: t, **subst} From 618f20bf1e971e292bc52fcb9cd43a57d036d4b2 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 9 Jan 2024 12:13:30 +0100 Subject: [PATCH 77/77] Clarify TypeApplication docstring --- guppy/hugr/ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index b44cea9b..02b393e3 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -485,7 +485,11 @@ class TypeApply(LeafOp): class TypeApplication(BaseModel): """Records details of an application of a PolyFuncType to some TypeArgs and the - result (a less-, but still potentially-, polymorphic type).""" + result (a less-, but still potentially-, polymorphic type). + + Note that Guppy only generates full type applications, where the result is a + monomorphic type. Partial type applications are not used by Guppy. + """ input: PolyFuncType args: list[tys.TypeArg]