diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 77603220..bfe6bae9 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -118,7 +118,9 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None: annotate_location(node, source, file, line_offset) -def annotate_location(node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True) -> None: +def annotate_location( + node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True +) -> None: setattr(node, "line_offset", line_offset) setattr(node, "file", file) setattr(node, "source", source) diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py index 6bd1e4b0..2648f696 100644 --- a/guppy/cfg/cfg.py +++ b/guppy/cfg/cfg.py @@ -25,7 +25,9 @@ class BaseCFG(Generic[T]): ass_before: Result[DefAssignmentDomain] maybe_ass_before: Result[MaybeAssignmentDomain] - def __init__(self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None): + def __init__( + self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None + ): self.bbs = bbs if entry_bb: self.entry_bb = entry_bb diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 52f94d5a..fdcfb54a 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -49,7 +49,9 @@ def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None: self.output_ty = output_ty -def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedCFG: +def check_cfg( + cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals +) -> CheckedCFG: """Type checks a control-flow graph. Annotates the basic blocks with input and output type signatures and removes @@ -61,18 +63,24 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) # We start by compiling the entry BB checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) - checked_cfg.entry_bb = check_bb(cfg.entry_bb, checked_cfg, inputs, return_ty, globals) + checked_cfg.entry_bb = check_bb( + cfg.entry_bb, checked_cfg, inputs, return_ty, globals + ) compiled = {cfg.entry_bb: checked_cfg.entry_bb} # Visit all control-flow edges in BFS order. We can't just do a normal loop over # all BBs since the input types for a BB are computed by checking a predecessor. # We do BFS instead of DFS to get a better error ordering. queue = collections.deque( - (checked_cfg.entry_bb, i, succ) for i, succ in enumerate(cfg.entry_bb.successors) + (checked_cfg.entry_bb, i, succ) + for i, succ in enumerate(cfg.entry_bb.successors) ) while len(queue) > 0: pred, num_output, bb = queue.popleft() - input_row = [Variable(v.name, v.ty, v.defined_at, None) for v in pred.sig.output_rows[num_output]] + input_row = [ + Variable(v.name, v.ty, v.defined_at, None) + for v in pred.sig.output_rows[num_output] + ] if bb in compiled: # If the BB was already compiled, we just have to check that the signatures @@ -81,9 +89,7 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) else: # Otherwise, check the BB and enqueue its successors checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) - queue += [ - (checked_bb, i, succ) for i, succ in enumerate(bb.successors) - ] + queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)] compiled[bb] = checked_bb # Link up BBs in the checked CFG @@ -94,11 +100,19 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} - checked_cfg.maybe_ass_before = {compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs} + checked_cfg.maybe_ass_before = { + compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs + } return checked_cfg -def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedBB: +def check_bb( + bb: BB, + checked_cfg: CheckedCFG, + inputs: VarRow, + return_ty: GuppyType, + globals: Globals, +) -> CheckedBB: cfg = bb.cfg # For the entry BB we have to separately check that all used variables are @@ -132,9 +146,7 @@ def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyTy f"Variable `{x}` is not defined on all control-flow paths.", use_bb.vars.used[x], ) - raise GuppyError( - f"Variable `{x}` is not defined", use_bb.vars.used[x] - ) + raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x]) # We have to check that used linear variables are not being outputted if x in ctx.locals: @@ -166,7 +178,9 @@ def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyTy ] # Also prepare the successor list so we can fill it in later - checked_bb = CheckedBB(bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs)) + checked_bb = CheckedBB( + bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) + ) checked_bb.successors = [None] * len(bb.successors) # type: ignore checked_bb.branch_pred = bb.branch_pred return checked_bb @@ -193,9 +207,7 @@ def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: v1, v2 = v2, v1 # We shouldn't mention temporary variables (starting with `%`) # in error messages: - ident = ( - "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" - ) + ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" raise GuppyError( f"{ident} can refer to different types: " f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", diff --git a/guppy/checker/core.py b/guppy/checker/core.py index bcbe47f3..77cde03e 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -4,8 +4,14 @@ from typing import NamedTuple, Optional, Union from guppy.ast_util import AstNode -from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType, NoneType, \ - BoolType +from guppy.guppy_types import ( + GuppyType, + FunctionType, + TupleType, + SumType, + NoneType, + BoolType, +) @dataclass @@ -25,11 +31,15 @@ class CallableVariable(ABC, Variable): ty: FunctionType @abstractmethod - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context") -> ast.expr: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" + ) -> ast.expr: """Checks the return type of a function call against a given type.""" @abstractmethod - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: "Context" + ) -> tuple[ast.expr, GuppyType]: """Synthesizes the return type of a function call.""" diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index f62fe5a7..a5297f57 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -52,7 +52,12 @@ def __init__(self, ctx: Context) -> None: self.ctx = ctx self._kind = "expression" - def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Optional[AstNode] = None) -> NoReturn: + def _fail( + self, + expected: GuppyType, + actual: Union[ast.expr, GuppyType], + loc: Optional[AstNode] = None, + ) -> NoReturn: """Raises a type error indicating that the type doesn't match.""" if not isinstance(actual, GuppyType): loc = loc or actual @@ -62,7 +67,9 @@ def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Op f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc ) - def check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + def check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: """Checks an expression against a type. Returns a new desugared expression with type annotations. @@ -108,7 +115,9 @@ def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: node, ty = self.visit(node) return with_type(ty, node), ty - def _check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + def _check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: """Checks an expression against a given type""" return ExprChecker(self.ctx).check(expr, ty, kind) @@ -218,7 +227,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: node.func, ty = self.synthesize(node.func) # First handle direct calls of user-defined functions and extension functions - if isinstance(node.func, GlobalName) and isinstance(node.func.value, CallableVariable): + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): return node.func.value.synthesize_call(node.args, node, self.ctx) # Otherwise, it must be a function as a higher-order value @@ -263,7 +274,9 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: ) -def synthesize_call(func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[list[ast.expr], GuppyType]: +def synthesize_call( + func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context +) -> tuple[list[ast.expr], GuppyType]: """Synthesizes the return type of a function call""" check_num_args(len(func_ty.args), len(args), node) for i, arg in enumerate(args): @@ -271,7 +284,13 @@ def synthesize_call(func_ty: FunctionType, args: list[ast.expr], node: AstNode, return args, func_ty.returns -def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> list[ast.expr]: +def check_call( + func_ty: FunctionType, + args: list[ast.expr], + ty: GuppyType, + node: AstNode, + ctx: Context, +) -> list[ast.expr]: """Checks the return type of a function call against a given type""" args, return_ty = synthesize_call(func_ty, args, node, ctx) if return_ty != ty: @@ -281,7 +300,9 @@ def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node: return args -def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr, GuppyType]: +def to_bool( + node: ast.expr, node_ty: GuppyType, ctx: Context +) -> tuple[ast.expr, GuppyType]: """Tries to turn a node into a bool""" if isinstance(node_ty, BoolType): return node, node_ty @@ -299,12 +320,14 @@ def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr, if not isinstance(return_ty, BoolType): raise GuppyTypeError( f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", - node + node, ) return call, return_ty -def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Optional[GuppyType]: +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals +) -> Optional[GuppyType]: """Turns a primitive Python value into a Guppy type. Returns `None` if the Python value cannot be represented in Guppy. @@ -316,4 +339,3 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Opti if isinstance(v, float): return globals.types["float"].build(node=node) return None - diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 13109781..facd10b6 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -15,20 +15,27 @@ @dataclass class DefinedFunction(CallableVariable): """A user-defined function""" + ty: FunctionType defined_at: ast.FunctionDef @staticmethod - def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DefinedFunction": + def from_ast( + func_def: ast.FunctionDef, name: str, globals: Globals + ) -> "DefinedFunction": ty = check_signature(func_def, globals) return DefinedFunction(name, ty, func_def, None) - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> GlobalCall: # Use default implementation from the expression checker args = check_call(self.ty, args, ty, node, ctx) return GlobalCall(func=self, args=args) - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty = synthesize_call(self.ty, args, node, ctx) return GlobalCall(func=self, args=args), ty @@ -57,7 +64,9 @@ def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFun return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) -def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) -> CheckedNestedFunctionDef: +def check_nested_func_def( + func_def: NestedFunctionDef, bb: BB, ctx: Context +) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" func_ty = check_signature(func_def, ctx.globals) assert func_ty.arg_names is not None @@ -102,7 +111,10 @@ def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) -> ) # Construct inputs for checking the body CFG - inputs = list(captured.values()) + [Variable(x, ty, func_def.args.args[i], None) for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args))] + inputs = list(captured.values()) + [ + Variable(x, ty, func_def.args.args[i], None) + for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) + ] globals = ctx.globals # Check if the body contains a recursive occurrence of the function name diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 96f891cb..5030c649 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -11,7 +11,6 @@ class StmtChecker(AstVisitor[BBStatement]): - ctx: Context bb: BB return_ty: GuppyType @@ -27,7 +26,9 @@ def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: return ExprSynthesizer(self.ctx).synthesize(node) - def _check_expr(self, node: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr: + def _check_expr( + self, node: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: return ExprChecker(self.ctx).check(node, ty, kind) def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: @@ -74,14 +75,18 @@ def visit_Assign(self, node: ast.Assign) -> ast.stmt: def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: if node.value is None: - raise GuppyError("Variable declaration is not supported. Assignment is required", node) + raise GuppyError( + "Variable declaration is not supported. Assignment is required", node + ) ty = type_from_ast(node.annotation, self.ctx.globals) node.value = self._check_expr(node.value, ty) self._check_assign(node.target, ty, node) return node def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: - bin_op = with_loc(node, ast.BinOp(left=node.target, op=node.op, right=node.value)) + bin_op = with_loc( + node, ast.BinOp(left=node.target, op=node.op, right=node.value) + ) assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) return self.visit_Assign(assign) @@ -103,7 +108,9 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: from guppy.checker.func_checker import check_nested_func_def func_def = check_nested_func_def(node, self.bb, self.ctx) - self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def, None) + self.ctx.locals[func_def.name] = Variable( + func_def.name, func_def.ty, func_def, None + ) return func_def def visit_If(self, node: ast.If) -> None: diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index bd31be7e..7446af8e 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -3,15 +3,22 @@ from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature from guppy.checker.core import Variable -from guppy.compiler.core import CompiledGlobals, is_return_var, DFContainer, return_var, \ - PortVariable +from guppy.compiler.core import ( + CompiledGlobals, + is_return_var, + DFContainer, + return_var, + PortVariable, +) from guppy.compiler.expr_compiler import ExprCompiler from guppy.compiler.stmt_compiler import StmtCompiler from guppy.guppy_types import TupleType, SumType, type_to_row from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV -def compile_cfg(cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals) -> None: +def compile_cfg( + cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals +) -> None: """Compiles a CFG to Hugr.""" insert_return_vars(cfg) @@ -23,7 +30,9 @@ def compile_cfg(cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlo graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None)) -def compile_bb(bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals) -> CFNode: +def compile_bb( + bb: CheckedBB, graph: Hugr, parent: Node, globals: CompiledGlobals +) -> CFNode: """Compiles a single basic block to Hugr.""" inputs = sort_vars(bb.sig.input_row) @@ -100,7 +109,10 @@ def insert_return_vars(cfg: CheckedCFG) -> None: `%ret0`, `%ret1`, etc. We update the exit BB signature to make sure they are correctly outputted. """ - return_vars = [Variable(return_var(i), ty, None, None) for i, ty in enumerate(type_to_row(cfg.output_ty))] + return_vars = [ + Variable(return_var(i), ty, None, None) + for i, ty in enumerate(type_to_row(cfg.output_ty)) + ] # Before patching, the exit BB shouldn't take any inputs assert len(cfg.exit_bb.sig.input_row) == 0 cfg.exit_bb.sig = Signature(return_vars, cfg.exit_bb.sig.output_rows) @@ -123,14 +135,13 @@ def choose_vars_for_pred( assert len(pred.ty.element_types) == len(output_vars) tuples = [ graph.add_make_tuple( - inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], parent=dfg.node + inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], + parent=dfg.node, ).out_port(0) for vs in output_vars ] tys = [t.ty for t in tuples] - conditional = graph.add_conditional( - cond_input=pred, inputs=tuples, parent=dfg.node - ) + conditional = graph.add_conditional(cond_input=pred, inputs=tuples, parent=dfg.node) for i, ty in enumerate(tys): case = graph.add_case(conditional) inp = graph.add_input(output_tys=tys, parent=case).out_port(i) @@ -158,4 +169,3 @@ def sort_vars(row: VarRow) -> list[Variable]: This determines the order in which they are outputted from a BB. """ return sorted(row, key=functools.cmp_to_key(compare_var)) - diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 9d93c54a..2e616e5a 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -17,7 +17,13 @@ class PortVariable(Variable): port: OutPortV - def __init__(self, name: str, port: OutPortV, defined_at: Optional[AstNode], used: Optional[AstNode] = None): + def __init__( + self, + name: str, + port: OutPortV, + defined_at: Optional[AstNode], + used: Optional[AstNode] = None, + ) -> None: super().__init__(name, port.ty, defined_at, used) object.__setattr__(self, "port", port) @@ -26,7 +32,9 @@ class CompiledVariable(ABC, Variable): """Abstract base class for compiled global module-level variables.""" @abstractmethod - def load(self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode) -> OutPortV: + def load( + self, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", node: AstNode + ) -> OutPortV: """Loads the variable as a value into a local dataflow graph.""" diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index 9a6e30ca..ee032cc9 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -3,8 +3,12 @@ from guppy.ast_util import AstNode from guppy.checker.func_checker import CheckedFunction, DefinedFunction from guppy.compiler.cfg_compiler import compile_cfg -from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer, \ - PortVariable +from guppy.compiler.core import ( + CompiledFunction, + CompiledGlobals, + DFContainer, + PortVariable, +) from guppy.guppy_types import type_to_row, FunctionType from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode from guppy.nodes import CheckedNestedFunctionDef @@ -14,15 +18,29 @@ class CompiledFunctionDef(DefinedFunction, CompiledFunction): node: DFContainingVNode - def load(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) - def compile_call(self, args: list[OutPortV], dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> list[OutPortV]: + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] -def compile_global_func_def(func: CheckedFunction, def_node: DFContainingVNode, graph: Hugr, globals: CompiledGlobals) -> CompiledFunctionDef: +def compile_global_func_def( + func: CheckedFunction, + def_node: DFContainingVNode, + graph: Hugr, + globals: CompiledGlobals, +) -> CompiledFunctionDef: """Compiles a top-level function definition to Hugr.""" def_input = graph.add_input(parent=def_node) cfg_node = graph.add_cfg( @@ -31,12 +49,20 @@ def compile_global_func_def(func: CheckedFunction, def_node: DFContainingVNode, compile_cfg(func.cfg, graph, cfg_node, globals) # Add output node for the cfg - graph.add_output(inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], parent=def_node) + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) -def compile_local_func_def(func: CheckedNestedFunctionDef, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals) -> PortVariable: +def compile_local_func_def( + func: CheckedNestedFunctionDef, + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, +) -> PortVariable: """Compiles a local (nested) function definition to Hugr.""" assert func.ty.arg_names is not None @@ -59,24 +85,32 @@ def compile_local_func_def(func: CheckedNestedFunctionDef, dfg: DFContainer, gra # variable if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]: loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0) - partial = graph.add_partial(loaded, [def_input.out_port(i) for i in range(len(captured))], def_node) + partial = graph.add_partial( + loaded, [def_input.out_port(i) for i in range(len(captured))], def_node + ) input_ports.append(partial.out_port(0)) func.cfg.input_tys.append(func.ty) else: # Otherwise, we treat the function like a normal global variable - globals = globals | {func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node)} + globals = globals | { + func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node) + } # Compile the CFG cfg_node = graph.add_cfg(def_node, inputs=input_ports) compile_cfg(func.cfg, graph, cfg_node, globals) # Add output node for the cfg - graph.add_output(inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], parent=def_node) + graph.add_output( + inputs=[cfg_node.add_out_port(ty) for ty in type_to_row(func.cfg.output_ty)], + parent=def_node, + ) # Finally, load the function into the local data-flow graph loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) if len(captured) > 0: - loaded = graph.add_partial(loaded, [dfg[v.name].port for v in captured], dfg.node).out_port(0) + loaded = graph.add_partial( + loaded, [dfg[v.name].port for v in captured], dfg.node + ).out_port(0) return PortVariable(func.name, loaded, func) - diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index a4269c0e..61b80dff 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -3,8 +3,13 @@ from guppy.ast_util import AstVisitor from guppy.checker.cfg_checker import CheckedBB -from guppy.compiler.core import CompilerBase, DFContainer, CompiledGlobals, \ - PortVariable, return_var +from guppy.compiler.core import ( + CompilerBase, + DFContainer, + CompiledGlobals, + PortVariable, + return_var, +) from guppy.compiler.expr_compiler import ExprCompiler from guppy.error import InternalGuppyError from guppy.guppy_types import TupleType @@ -86,6 +91,6 @@ def visit_Return(self, node: ast.Return) -> None: def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: from guppy.compiler.func_compiler import compile_local_func_def - self.dfg[node.name] = compile_local_func_def(node, self.dfg, self.graph, self.globals) - - + self.dfg[node.name] = compile_local_func_def( + node, self.dfg, self.graph, self.globals + ) diff --git a/guppy/custom.py b/guppy/custom.py index c888746f..0c03f2a8 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -8,8 +8,12 @@ from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.compiler.expr_compiler import ExprCompiler -from guppy.error import GuppyError, InternalGuppyError, UnknownFunctionType, \ - GuppyTypeError +from guppy.error import ( + GuppyError, + InternalGuppyError, + UnknownFunctionType, + GuppyTypeError, +) from guppy.guppy_types import GuppyType, FunctionType, type_to_row, TupleType from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode @@ -31,7 +35,15 @@ class CustomFunction(CompiledFunction): _ty: Optional[FunctionType] = None _defined: dict[Node, DFContainingVNode] = {} - def __init__(self, name: str, defined_at: Optional[ast.FunctionDef], compiler: "CustomCallCompiler", checker: "CustomCallChecker", higher_order_value: bool = True, ty: Optional[FunctionType] = None): + def __init__( + self, + name: str, + defined_at: Optional[ast.FunctionDef], + compiler: "CustomCallCompiler", + checker: "CustomCallChecker", + higher_order_value: bool = True, + ty: Optional[FunctionType] = None, + ): self.name = name self.defined_at = defined_at self.higher_order_value = higher_order_value @@ -73,11 +85,15 @@ def check_type(self, globals: Globals) -> None: if self.call_checker is None or self.higher_order_value: raise err - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> ast.expr: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> ast.expr: self.call_checker._setup(ctx, node, self) return with_type(ty, with_loc(node, self.call_checker.check(args, ty))) - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: "Context" + ) -> tuple[ast.expr, GuppyType]: self.call_checker._setup(ctx, node, self) new_node, ty = self.call_checker.synthesize(args) return with_type(ty, with_loc(node, new_node)), ty @@ -93,7 +109,9 @@ def compile_call( self.call_compiler._setup(dfg, graph, globals, node) return self.call_compiler.compile(args) - def load(self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + def load( + self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: """Loads the custom function as a value into a local dataflow graph. This will place a `FunctionDef` node into the Hugr module if one for this @@ -175,7 +193,9 @@ class CustomCallCompiler(ABC): globals: CompiledGlobals node: AstNode - def _setup(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> None: + def _setup( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> None: self.dfg = dfg self.graph = graph self.globals = globals @@ -226,6 +246,5 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: class NoopCompiler(CustomCallCompiler): - def compile(self, args: list[OutPortV]) -> list[OutPortV]: return args diff --git a/guppy/declared.py b/guppy/declared.py index 167f9abd..257c74ad 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -20,18 +20,26 @@ class DeclaredFunction(CompiledFunction): node: Optional[VNode] = None @staticmethod - def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DeclaredFunction": + def from_ast( + func_def: ast.FunctionDef, name: str, globals: Globals + ) -> "DeclaredFunction": ty = check_signature(func_def, globals) if not has_empty_body(func_def): - raise GuppyError("Body of function declaration must be empty", func_def.body[0]) + raise GuppyError( + "Body of function declaration must be empty", func_def.body[0] + ) return DeclaredFunction(name, ty, func_def, None) - def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall: + def check_call( + self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + ) -> GlobalCall: # Use default implementation from the expression checker args = check_call(self.ty, args, ty, node, ctx) return GlobalCall(func=self, args=args) - def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]: + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty = synthesize_call(self.ty, args, node, ctx) return GlobalCall(func=self, args=args), ty @@ -39,11 +47,20 @@ def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) - def load(self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> OutPortV: + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: assert self.node is not None return graph.add_load_constant(self.node.out_port(0), dfg.node).out_port(0) - def compile_call(self, args: list[OutPortV], dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode) -> list[OutPortV]: + def compile_call( + self, + args: list[OutPortV], + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: assert self.node is not None call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/decorator.py b/guppy/decorator.py index 894a94fa..027e032f 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -3,8 +3,14 @@ from typing import Optional, Union, Callable, Any from guppy.ast_util import AstNode, has_empty_body -from guppy.custom import CustomFunction, OpCompiler, DefaultCallChecker, \ - CustomCallCompiler, CustomCallChecker, DefaultCallCompiler +from guppy.custom import ( + CustomFunction, + OpCompiler, + DefaultCallChecker, + CustomCallCompiler, + CustomCallChecker, + DefaultCallCompiler, +) from guppy.error import GuppyError, pretty_errors from guppy.guppy_types import GuppyType from guppy.hugr import tys, ops @@ -29,13 +35,16 @@ def set_module(self, module: GuppyModule) -> None: self._module = module @pretty_errors - def __call__(self, arg: Union[PyFunc, GuppyModule]) -> Union[Optional[Hugr], FuncDecorator]: + def __call__( + self, arg: Union[PyFunc, GuppyModule] + ) -> Union[Optional[Hugr], FuncDecorator]: """Decorator to annotate Python functions as Guppy code. Optionally, the `GuppyModule` in which the function should be placed can be passed to the decorator. """ if isinstance(arg, GuppyModule): + def dec(f: Callable[..., Any]) -> Callable[..., Any]: assert isinstance(arg, GuppyModule) arg.register_func_def(f) @@ -66,7 +75,13 @@ def dec(c: type) -> type: return dec @pretty_errors - def type(self, module: GuppyModule, hugr_ty: tys.SimpleType, name: str = "", linear: bool = False) -> ClassDecorator: + def type( + self, + module: GuppyModule, + hugr_ty: tys.SimpleType, + name: str = "", + linear: bool = False, + ) -> ClassDecorator: """Decorator to annotate a class definitions as Guppy types. Requires the static Hugr translation of the type. Additionally, the type can be @@ -111,7 +126,14 @@ def __str__(self) -> str: return dec @pretty_errors - def custom(self, module: GuppyModule, compiler: Optional[CustomCallCompiler] = None, checker: Optional[CustomCallChecker] = None, higher_order_value: bool = True, name: str = "") -> CustomFuncDecorator: + def custom( + self, + module: GuppyModule, + compiler: Optional[CustomCallCompiler] = None, + checker: Optional[CustomCallChecker] = None, + higher_order_value: bool = True, + name: str = "", + ) -> CustomFuncDecorator: """Decorator to add custom typing or compilation behaviour to function decls. Optionally, usage of the function as a higher-order value can be disabled. In @@ -124,22 +146,36 @@ def dec(f: PyFunc) -> CustomFunction: if not has_empty_body(func_ast): raise GuppyError( "Body of custom function declaration must be empty", - func_ast.body[0] + func_ast.body[0], ) call_checker = checker or DefaultCallChecker() - func = CustomFunction(name or func_ast.name, func_ast, compiler or DefaultCallCompiler(), call_checker, higher_order_value) + func = CustomFunction( + name or func_ast.name, + func_ast, + compiler or DefaultCallCompiler(), + call_checker, + higher_order_value, + ) call_checker.func = func module.register_custom_func(func) return func return dec - def hugr_op(self, module: GuppyModule, op: ops.OpType, checker: Optional[CustomCallChecker] = None, higher_order_value: bool = True, name: str = "") -> CustomFuncDecorator: + def hugr_op( + self, + module: GuppyModule, + op: ops.OpType, + checker: Optional[CustomCallChecker] = None, + higher_order_value: bool = True, + name: str = "", + ) -> CustomFuncDecorator: """Decorator to annotate function declarations as HUGR ops.""" return self.custom(module, OpCompiler(op), checker, higher_order_value, name) def declare(self, module: GuppyModule) -> FuncDecorator: """Decorator to declare functions""" + def dec(f: Callable[..., Any]) -> Callable[..., Any]: module.register_func_decl(f) @@ -154,6 +190,4 @@ def dummy(*args: Any, **kwargs: Any) -> Any: return dec - - guppy = _Guppy() diff --git a/guppy/error.py b/guppy/error.py index 81dfec17..77a08cea 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -176,4 +176,3 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return None return cast(FuncT, wrapped) - diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index fa34bbef..d9f916e6 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -12,7 +12,9 @@ GuppyType, TupleType, FunctionType, - SumType, type_to_row, row_to_type, + SumType, + type_to_row, + row_to_type, ) from guppy.hugr import val @@ -483,7 +485,11 @@ def add_call( """Adds a `Call` node to the graph.""" assert isinstance(def_port.ty, FunctionType) return self.add_node( - ops.Call(), None, list(type_to_row(def_port.ty.returns)), parent, args + [def_port] + ops.Call(), + None, + list(type_to_row(def_port.ty.returns)), + parent, + args + [def_port], ) def add_indirect_call( @@ -628,7 +634,9 @@ def remove_dummy_nodes(self) -> "Hugr": for n in list(self.nodes()): if isinstance(n, VNode) and isinstance(n.op, ops.DummyOp): name = n.op.name - fun_ty = FunctionType(list(n.in_port_types), row_to_type(n.out_port_types)) + fun_ty = FunctionType( + list(n.in_port_types), row_to_type(n.out_port_types) + ) if name in used_names: used_names[name] += 1 name = f"{name}${used_names[name]}" diff --git a/guppy/module.py b/guppy/module.py index 2b501291..a88a5dc5 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -61,6 +61,7 @@ def __init__(self, name: str, import_builtins: bool = True): # Import builtin module if import_builtins: import guppy.prelude.builtins as builtins + self.load(builtins) def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: @@ -72,7 +73,9 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: m.compile() # For now, we can only import custom functions - if any(not isinstance(v, CustomFunction) for v in m._compiled_globals.values()): + if any( + not isinstance(v, CustomFunction) for v in m._compiled_globals.values() + ): raise GuppyError( "Importing modules with defined functions is not supported yet" ) @@ -84,14 +87,18 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: if isinstance(val, GuppyModule): self.load(val) - def register_func_def(self, f: PyFunc, instance: Optional[type[GuppyType]] = None) -> None: + def register_func_def( + self, f: PyFunc, instance: Optional[type[GuppyType]] = None + ) -> None: """Registers a Python function definition as belonging to this Guppy module.""" self._check_not_yet_compiled() func_ast = parse_py_func(f) if self._instance_func_buffer is not None: self._instance_func_buffer[func_ast.name] = f else: - name = qualified_name(instance, func_ast.name) if instance else func_ast.name + name = ( + qualified_name(instance, func_ast.name) if instance else func_ast.name + ) self._check_name_available(name, func_ast) self._func_defs[name] = func_ast @@ -102,7 +109,9 @@ def register_func_decl(self, f: PyFunc) -> None: self._check_name_available(func_ast.name, func_ast) self._func_decls[func_ast.name] = func_ast - def register_custom_func(self, func: CustomFunction, instance: Optional[type[GuppyType]] = None) -> None: + def register_custom_func( + self, func: CustomFunction, instance: Optional[type[GuppyType]] = None + ) -> None: """Registers a custom function as belonging to this Guppy module.""" self._check_not_yet_compiled() if self._instance_func_buffer is not None: @@ -153,7 +162,10 @@ def compile(self) -> Optional[Hugr]: self._globals.values.update(defined_funcs) # Type check function definitions - checked = {x: check_global_func_def(f, self._imported_globals | self._globals) for x, f in defined_funcs.items()} + checked = { + x: check_global_func_def(f, self._imported_globals | self._globals) + for x, f in defined_funcs.items() + } # Add declared functions to the graph graph = Hugr(self.name) @@ -163,14 +175,23 @@ def compile(self) -> Optional[Hugr]: # Prepare `FunctionDef` nodes for all function definitions def_nodes = {x: graph.add_def(f.ty, module_node, x) for x, f in checked.items()} - self._compiled_globals |= self._custom_funcs | declared_funcs | { - x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) - for x, f in checked.items() - } + self._compiled_globals |= ( + self._custom_funcs + | declared_funcs + | { + x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) + for x, f in checked.items() + } + ) # Compile function definitions to Hugr for x, f in checked.items(): - compile_global_func_def(f, def_nodes[x], graph, self._imported_compiled_globals | self._compiled_globals) + compile_global_func_def( + f, + def_nodes[x], + graph, + self._imported_compiled_globals | self._compiled_globals, + ) self._compiled = True return graph diff --git a/guppy/nodes.py b/guppy/nodes.py index a84f9de0..56795f1d 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -14,9 +14,7 @@ class LocalName(ast.expr): id: str - _fields = ( - 'id', - ) + _fields = ("id",) class GlobalName(ast.expr): @@ -24,8 +22,8 @@ class GlobalName(ast.expr): value: "Variable" _fields = ( - 'id', - 'value', + "id", + "value", ) @@ -34,8 +32,8 @@ class LocalCall(ast.expr): args: list[ast.expr] _fields = ( - 'func', - 'args', + "func", + "args", ) @@ -46,8 +44,8 @@ class GlobalCall(ast.expr): # Later: Inferred type args _fields = ( - 'func', - 'args', + "func", + "args", ) @@ -66,7 +64,14 @@ class CheckedNestedFunctionDef(ast.FunctionDef): ty: FunctionType captured: Mapping[str, "Variable"] - def __init__(self, cfg: "CheckedCFG", ty: FunctionType, captured: Mapping[str, "Variable"], *args: Any, **kwargs: Any) -> None: + def __init__( + self, + cfg: "CheckedCFG", + ty: FunctionType, + captured: Mapping[str, "Variable"], + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.cfg = cfg self.ty = ty diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 7c410f7c..c7b257e6 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -6,8 +6,12 @@ from guppy.ast_util import with_type, AstNode, with_loc, get_type from guppy.checker.core import Context, CallableVariable from guppy.checker.expr_checker import ExprSynthesizer, check_num_args -from guppy.custom import CustomCallChecker, DefaultCallChecker, CustomFunction, \ - CustomCallCompiler +from guppy.custom import ( + CustomCallChecker, + DefaultCallChecker, + CustomFunction, + CustomCallCompiler, +) from guppy.error import GuppyTypeError from guppy.guppy_types import GuppyType, type_to_row, FunctionType, BoolType from guppy.hugr import ops, tys, val @@ -69,7 +73,9 @@ def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) -def int_op(op_name: str, ext: str = "arithmetic.int", num_params: int = 1) -> ops.OpType: +def int_op( + op_name: str, ext: str = "arithmetic.int", num_params: int = 1 +) -> ops.OpType: """Utility method to create Hugr integer arithmetic ops.""" return ops.CustomOp( extension=ext, @@ -92,8 +98,12 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: for i in range(len(args)): args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) if isinstance(ty, self.ctx.globals.types["int"]): - call = with_loc(self.node, GlobalCall(func=Int.__float__, args=[args[i]])) - args[i] = with_type(self.ctx.globals.types["float"].build(node=self.node), call) + call = with_loc( + self.node, GlobalCall(func=Int.__float__, args=[args[i]]) + ) + args[i] = with_type( + self.ctx.globals.types["float"].build(node=self.node), call + ) return super().synthesize(args) @@ -148,7 +158,9 @@ def __init__(self, dunder_name: str, num_args: int = 1): self.dunder_name = dunder_name self.num_args = num_args - def _get_func(self, args: list[ast.expr]) -> tuple[list[ast.expr], CallableVariable]: + def _get_func( + self, args: list[ast.expr] + ) -> tuple[list[ast.expr], CallableVariable]: check_num_args(self.num_args, len(args), self.node) fst, *rest = args fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) @@ -193,9 +205,15 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # Compile `truediv` using float arithmetic [left, right] = args - [left] = Int.__float__.compile_call([left], self.dfg, self.graph, self.globals, self.node) - [right] = Int.__float__.compile_call([right], self.dfg, self.graph, self.globals, self.node) - return Float.__truediv__.compile_call([left, right], self.dfg, self.graph, self.globals, self.node) + [left] = Int.__float__.compile_call( + [left], self.dfg, self.graph, self.globals, self.node + ) + [right] = Int.__float__.compile_call( + [right], self.dfg, self.graph, self.globals, self.node + ) + return Float.__truediv__.compile_call( + [left, right], self.dfg, self.graph, self.globals, self.node + ) class FloatBoolCompiler(CustomCallCompiler): @@ -205,7 +223,9 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: from .builtins import Float # We have: bool(x) = (x != 0.0) - zero_const = self.graph.add_constant(float_value(0.0), get_type(self.node), self.dfg.node) + zero_const = self.graph.add_constant( + float_value(0.0), get_type(self.node), self.dfg.node + ) zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) return Float.__ne__.compile_call( [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node @@ -260,4 +280,4 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: [mod] = Float.__mod__.compile_call( args, self.dfg, self.graph, self.globals, self.node ) - return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] \ No newline at end of file + return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index c684aac6..d83d4633 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -7,10 +7,23 @@ from guppy.guppy_types import BoolType from guppy.hugr import tys from guppy.module import GuppyModule -from guppy.prelude._internal import logic_op, int_op, hugr_int_type, hugr_float_type, \ - float_op, CoercingChecker, ReversingChecker, IntTruedivCompiler, FloatBoolCompiler, \ - FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, \ - NotImplementedCompiler, DunderChecker, CallableChecker +from guppy.prelude._internal import ( + logic_op, + int_op, + hugr_int_type, + hugr_float_type, + float_op, + CoercingChecker, + ReversingChecker, + IntTruedivCompiler, + FloatBoolCompiler, + FloatDivmodCompiler, + FloatFloordivCompiler, + FloatModCompiler, + NotImplementedCompiler, + DunderChecker, + CallableChecker, +) builtins = GuppyModule("builtins", import_builtins=False) @@ -18,7 +31,6 @@ @guppy.extend_type(builtins, BoolType) class Bool: - @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) def __and__(self: bool, other: bool) -> bool: ... @@ -38,7 +50,6 @@ def __or__(self: bool, other: bool) -> bool: @guppy.type(builtins, hugr_int_type, name="int") class Int: - @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) def __abs__(self: int) -> int: ... @@ -131,7 +142,9 @@ def __or__(self: int, other: int) -> int: def __pos__(self: int) -> int: ... - @guppy.custom(builtins, NotImplementedCompiler("ipow"), DefaultCallChecker()) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("ipow"), DefaultCallChecker() + ) # TODO def __pow__(self: int, other: int) -> int: ... @@ -143,19 +156,29 @@ def __radd__(self: int, other: int) -> int: def __rand__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("idivmod_s", num_params=2), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op( + builtins, + int_op("idivmod_s", num_params=2), + ReversingChecker(DefaultCallChecker()), + ) def __rdivmod__(self: int, other: int) -> tuple[int, int]: ... - @guppy.hugr_op(builtins, int_op("idiv_s", num_params=2), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op( + builtins, int_op("idiv_s", num_params=2), ReversingChecker(DefaultCallChecker()) + ) def __rfloordiv__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("ishl", num_params=2), ReversingChecker(DefaultCallChecker())) # TODO: RHS is unsigned + @guppy.hugr_op( + builtins, int_op("ishl", num_params=2), ReversingChecker(DefaultCallChecker()) + ) # TODO: RHS is unsigned def __rlshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("imod_s", num_params=2), ReversingChecker(DefaultCallChecker())) + @guppy.hugr_op( + builtins, int_op("imod_s", num_params=2), ReversingChecker(DefaultCallChecker()) + ) def __rmod__(self: int, other: int) -> int: ... @@ -171,11 +194,15 @@ def __ror__(self: int, other: int) -> int: def __round__(self: int) -> int: ... - @guppy.custom(builtins, NotImplementedCompiler("ipow"), ReversingChecker(DefaultCallChecker())) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("ipow"), ReversingChecker(DefaultCallChecker()) + ) # TODO def __rpow__(self: int, other: int) -> int: ... - @guppy.hugr_op(builtins, int_op("ishr", num_params=2), ReversingChecker(DefaultCallChecker())) # TODO: RHS is unsigned + @guppy.hugr_op( + builtins, int_op("ishr", num_params=2), ReversingChecker(DefaultCallChecker()) + ) # TODO: RHS is unsigned def __rrshift__(self: int, other: int) -> int: ... @@ -187,7 +214,9 @@ def __rshift__(self: int, other: int) -> int: def __rsub__(self: int, other: int) -> int: ... - @guppy.custom(builtins, IntTruedivCompiler(), ReversingChecker(DefaultCallChecker())) + @guppy.custom( + builtins, IntTruedivCompiler(), ReversingChecker(DefaultCallChecker()) + ) def __rtruediv__(self: int, other: int) -> float: ... @@ -214,7 +243,6 @@ def __xor__(self: int, other: int) -> int: @guppy.type(builtins, hugr_float_type, name="float") class Float: - @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) def __abs__(self: float) -> float: ... @@ -259,7 +287,9 @@ def __ge__(self: float, other: float) -> bool: def __gt__(self: float, other: float) -> bool: ... - @guppy.hugr_op(builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker()) + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) def __int__(self: float) -> int: ... @@ -303,7 +333,9 @@ def __radd__(self: float, other: float) -> float: def __rdivmod__(self: float, other: float) -> tuple[float, float]: ... - @guppy.custom(builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker())) + @guppy.custom( + builtins, FloatFloordivCompiler(), ReversingChecker(CoercingChecker()) + ) def __rfloordiv__(self: float, other: float) -> float: ... @@ -315,11 +347,15 @@ def __rmod__(self: float, other: float) -> float: def __rmul__(self: float, other: float) -> float: ... - @guppy.custom(builtins, NotImplementedCompiler("fround"), ReversingChecker(CoercingChecker())) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("fround"), ReversingChecker(CoercingChecker()) + ) # TODO def __round__(self: float) -> float: ... - @guppy.custom(builtins, NotImplementedCompiler("fpow"), ReversingChecker(CoercingChecker())) # TODO + @guppy.custom( + builtins, NotImplementedCompiler("fpow"), ReversingChecker(CoercingChecker()) + ) # TODO def __rpow__(self: float, other: float) -> float: ... @@ -339,7 +375,9 @@ def __sub__(self: float, other: float) -> float: def __truediv__(self: float, other: float) -> float: ... - @guppy.hugr_op(builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker()) + @guppy.hugr_op( + builtins, float_op("trunc_s", "arithmetic.conversions"), CoercingChecker() + ) def __trunc__(self: float) -> float: ... @@ -349,7 +387,9 @@ def abs(x): ... -@guppy.custom(builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False) +@guppy.custom( + builtins, name="bool", checker=DunderChecker("__bool__"), higher_order_value=False +) def _bool(x): ... @@ -359,22 +399,30 @@ def callable(x): ... -@guppy.custom(builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False) +@guppy.custom( + builtins, checker=DunderChecker("__divmod__", num_args=2), higher_order_value=False +) def divmod(x, y): ... -@guppy.custom(builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False) +@guppy.custom( + builtins, name="float", checker=DunderChecker("__float__"), higher_order_value=False +) def _float(x, y): ... -@guppy.custom(builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False) +@guppy.custom( + builtins, name="int", checker=DunderChecker("__int__"), higher_order_value=False +) def _int(x): ... -@guppy.custom(builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False) +@guppy.custom( + builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False +) def pow(x, y): ... @@ -382,4 +430,3 @@ def pow(x, y): @guppy.custom(builtins, checker=DunderChecker("__round__"), higher_order_value=False) def round(x): ... - diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index fd1b9bc3..3b40af88 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -16,7 +16,11 @@ def quantum_op(op_name: str) -> ops.OpType: return ops.CustomOp(extension="quantum.tket2", op_name=op_name, args=[]) -@guppy.type(quantum, tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), linear=True) +@guppy.type( + quantum, + tys.Opaque(extension="prelude", id="qubit", args=[], bound=TypeBound.Any), + linear=True, +) class Qubit: pass