Skip to content

Commit

Permalink
Run formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Nov 21, 2023
1 parent 6db0327 commit 8616f6b
Show file tree
Hide file tree
Showing 21 changed files with 455 additions and 157 deletions.
4 changes: 3 additions & 1 deletion guppy/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None:
annotate_location(node, source, file, line_offset)


def annotate_location(node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True) -> None:
def annotate_location(
node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True
) -> None:
setattr(node, "line_offset", line_offset)
setattr(node, "file", file)
setattr(node, "source", source)
Expand Down
4 changes: 3 additions & 1 deletion guppy/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class BaseCFG(Generic[T]):
ass_before: Result[DefAssignmentDomain]
maybe_ass_before: Result[MaybeAssignmentDomain]

def __init__(self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None):
def __init__(
self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None
):
self.bbs = bbs
if entry_bb:
self.entry_bb = entry_bb
Expand Down
44 changes: 28 additions & 16 deletions guppy/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None:
self.output_ty = output_ty


def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedCFG:
def check_cfg(
cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals
) -> CheckedCFG:
"""Type checks a control-flow graph.
Annotates the basic blocks with input and output type signatures and removes
Expand All @@ -61,18 +63,24 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals)

# We start by compiling the entry BB
checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty)
checked_cfg.entry_bb = check_bb(cfg.entry_bb, checked_cfg, inputs, return_ty, globals)
checked_cfg.entry_bb = check_bb(
cfg.entry_bb, checked_cfg, inputs, return_ty, globals
)
compiled = {cfg.entry_bb: checked_cfg.entry_bb}

# Visit all control-flow edges in BFS order. We can't just do a normal loop over
# all BBs since the input types for a BB are computed by checking a predecessor.
# We do BFS instead of DFS to get a better error ordering.
queue = collections.deque(
(checked_cfg.entry_bb, i, succ) for i, succ in enumerate(cfg.entry_bb.successors)
(checked_cfg.entry_bb, i, succ)
for i, succ in enumerate(cfg.entry_bb.successors)
)
while len(queue) > 0:
pred, num_output, bb = queue.popleft()
input_row = [Variable(v.name, v.ty, v.defined_at, None) for v in pred.sig.output_rows[num_output]]
input_row = [
Variable(v.name, v.ty, v.defined_at, None)
for v in pred.sig.output_rows[num_output]
]

if bb in compiled:
# If the BB was already compiled, we just have to check that the signatures
Expand All @@ -81,9 +89,7 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals)
else:
# Otherwise, check the BB and enqueue its successors
checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals)
queue += [
(checked_bb, i, succ) for i, succ in enumerate(bb.successors)
]
queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)]
compiled[bb] = checked_bb

# Link up BBs in the checked CFG
Expand All @@ -94,11 +100,19 @@ def check_cfg(cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals)
checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable
checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs}
checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs}
checked_cfg.maybe_ass_before = {compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs}
checked_cfg.maybe_ass_before = {
compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
}
return checked_cfg


def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyType, globals: Globals) -> CheckedBB:
def check_bb(
bb: BB,
checked_cfg: CheckedCFG,
inputs: VarRow,
return_ty: GuppyType,
globals: Globals,
) -> CheckedBB:
cfg = bb.cfg

# For the entry BB we have to separately check that all used variables are
Expand Down Expand Up @@ -132,9 +146,7 @@ def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyTy
f"Variable `{x}` is not defined on all control-flow paths.",
use_bb.vars.used[x],
)
raise GuppyError(
f"Variable `{x}` is not defined", use_bb.vars.used[x]
)
raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x])

# We have to check that used linear variables are not being outputted
if x in ctx.locals:
Expand Down Expand Up @@ -166,7 +178,9 @@ def check_bb(bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, return_ty: GuppyTy
]

# Also prepare the successor list so we can fill it in later
checked_bb = CheckedBB(bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs))
checked_bb = CheckedBB(
bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs)
)
checked_bb.successors = [None] * len(bb.successors) # type: ignore
checked_bb.branch_pred = bb.branch_pred
return checked_bb
Expand All @@ -193,9 +207,7 @@ def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None:
v1, v2 = v2, v1
# We shouldn't mention temporary variables (starting with `%`)
# in error messages:
ident = (
"Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`"
)
ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`"
raise GuppyError(
f"{ident} can refer to different types: "
f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})",
Expand Down
18 changes: 14 additions & 4 deletions guppy/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
from typing import NamedTuple, Optional, Union

from guppy.ast_util import AstNode
from guppy.guppy_types import GuppyType, FunctionType, TupleType, SumType, NoneType, \
BoolType
from guppy.guppy_types import (
GuppyType,
FunctionType,
TupleType,
SumType,
NoneType,
BoolType,
)


@dataclass
Expand All @@ -25,11 +31,15 @@ class CallableVariable(ABC, Variable):
ty: FunctionType

@abstractmethod
def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context") -> ast.expr:
def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context"
) -> ast.expr:
"""Checks the return type of a function call against a given type."""

@abstractmethod
def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: "Context") -> tuple[ast.expr, GuppyType]:
def synthesize_call(
self, args: list[ast.expr], node: AstNode, ctx: "Context"
) -> tuple[ast.expr, GuppyType]:
"""Synthesizes the return type of a function call."""


Expand Down
42 changes: 32 additions & 10 deletions guppy/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ def __init__(self, ctx: Context) -> None:
self.ctx = ctx
self._kind = "expression"

def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Optional[AstNode] = None) -> NoReturn:
def _fail(
self,
expected: GuppyType,
actual: Union[ast.expr, GuppyType],
loc: Optional[AstNode] = None,
) -> NoReturn:
"""Raises a type error indicating that the type doesn't match."""
if not isinstance(actual, GuppyType):
loc = loc or actual
Expand All @@ -62,7 +67,9 @@ def _fail(self, expected: GuppyType, actual: Union[ast.expr, GuppyType], loc: Op
f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc
)

def check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr:
def check(
self, expr: ast.expr, ty: GuppyType, kind: str = "expression"
) -> ast.expr:
"""Checks an expression against a type.
Returns a new desugared expression with type annotations.
Expand Down Expand Up @@ -108,7 +115,9 @@ def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]:
node, ty = self.visit(node)
return with_type(ty, node), ty

def _check(self, expr: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr:
def _check(
self, expr: ast.expr, ty: GuppyType, kind: str = "expression"
) -> ast.expr:
"""Checks an expression against a given type"""
return ExprChecker(self.ctx).check(expr, ty, kind)

Expand Down Expand Up @@ -218,7 +227,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
node.func, ty = self.synthesize(node.func)

# First handle direct calls of user-defined functions and extension functions
if isinstance(node.func, GlobalName) and isinstance(node.func.value, CallableVariable):
if isinstance(node.func, GlobalName) and isinstance(
node.func.value, CallableVariable
):
return node.func.value.synthesize_call(node.args, node, self.ctx)

# Otherwise, it must be a function as a higher-order value
Expand Down Expand Up @@ -263,15 +274,23 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None:
)


def synthesize_call(func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[list[ast.expr], GuppyType]:
def synthesize_call(
func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], GuppyType]:
"""Synthesizes the return type of a function call"""
check_num_args(len(func_ty.args), len(args), node)
for i, arg in enumerate(args):
args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument")
return args, func_ty.returns


def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> list[ast.expr]:
def check_call(
func_ty: FunctionType,
args: list[ast.expr],
ty: GuppyType,
node: AstNode,
ctx: Context,
) -> list[ast.expr]:
"""Checks the return type of a function call against a given type"""
args, return_ty = synthesize_call(func_ty, args, node, ctx)
if return_ty != ty:
Expand All @@ -281,7 +300,9 @@ def check_call(func_ty: FunctionType, args: list[ast.expr], ty: GuppyType, node:
return args


def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr, GuppyType]:
def to_bool(
node: ast.expr, node_ty: GuppyType, ctx: Context
) -> tuple[ast.expr, GuppyType]:
"""Tries to turn a node into a bool"""
if isinstance(node_ty, BoolType):
return node, node_ty
Expand All @@ -299,12 +320,14 @@ def to_bool(node: ast.expr, node_ty: GuppyType, ctx: Context) -> tuple[ast.expr,
if not isinstance(return_ty, BoolType):
raise GuppyTypeError(
f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`",
node
node,
)
return call, return_ty


def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Optional[GuppyType]:
def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals
) -> Optional[GuppyType]:
"""Turns a primitive Python value into a Guppy type.
Returns `None` if the Python value cannot be represented in Guppy.
Expand All @@ -316,4 +339,3 @@ def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Opti
if isinstance(v, float):
return globals.types["float"].build(node=node)
return None

22 changes: 17 additions & 5 deletions guppy/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,27 @@
@dataclass
class DefinedFunction(CallableVariable):
"""A user-defined function"""

ty: FunctionType
defined_at: ast.FunctionDef

@staticmethod
def from_ast(func_def: ast.FunctionDef, name: str, globals: Globals) -> "DefinedFunction":
def from_ast(
func_def: ast.FunctionDef, name: str, globals: Globals
) -> "DefinedFunction":
ty = check_signature(func_def, globals)
return DefinedFunction(name, ty, func_def, None)

def check_call(self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context) -> GlobalCall:
def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context
) -> GlobalCall:
# Use default implementation from the expression checker
args = check_call(self.ty, args, ty, node, ctx)
return GlobalCall(func=self, args=args)

def synthesize_call(self, args: list[ast.expr], node: AstNode, ctx: Context) -> tuple[GlobalCall, GuppyType]:
def synthesize_call(
self, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[GlobalCall, GuppyType]:
# Use default implementation from the expression checker
args, ty = synthesize_call(self.ty, args, node, ctx)
return GlobalCall(func=self, args=args), ty
Expand Down Expand Up @@ -57,7 +64,9 @@ def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFun
return CheckedFunction(func_def.name, func.ty, func_def, None, cfg)


def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) -> CheckedNestedFunctionDef:
def check_nested_func_def(
func_def: NestedFunctionDef, bb: BB, ctx: Context
) -> CheckedNestedFunctionDef:
"""Type checks a local (nested) function definition."""
func_ty = check_signature(func_def, ctx.globals)
assert func_ty.arg_names is not None
Expand Down Expand Up @@ -102,7 +111,10 @@ def check_nested_func_def(func_def: NestedFunctionDef, bb: BB, ctx: Context) ->
)

# Construct inputs for checking the body CFG
inputs = list(captured.values()) + [Variable(x, ty, func_def.args.args[i], None) for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args))]
inputs = list(captured.values()) + [
Variable(x, ty, func_def.args.args[i], None)
for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args))
]
globals = ctx.globals

# Check if the body contains a recursive occurrence of the function name
Expand Down
17 changes: 12 additions & 5 deletions guppy/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class StmtChecker(AstVisitor[BBStatement]):

ctx: Context
bb: BB
return_ty: GuppyType
Expand All @@ -27,7 +26,9 @@ def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]:
def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]:
return ExprSynthesizer(self.ctx).synthesize(node)

def _check_expr(self, node: ast.expr, ty: GuppyType, kind: str = "expression") -> ast.expr:
def _check_expr(
self, node: ast.expr, ty: GuppyType, kind: str = "expression"
) -> ast.expr:
return ExprChecker(self.ctx).check(node, ty, kind)

def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None:
Expand Down Expand Up @@ -74,14 +75,18 @@ def visit_Assign(self, node: ast.Assign) -> ast.stmt:

def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
if node.value is None:
raise GuppyError("Variable declaration is not supported. Assignment is required", node)
raise GuppyError(
"Variable declaration is not supported. Assignment is required", node
)
ty = type_from_ast(node.annotation, self.ctx.globals)
node.value = self._check_expr(node.value, ty)
self._check_assign(node.target, ty, node)
return node

def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt:
bin_op = with_loc(node, ast.BinOp(left=node.target, op=node.op, right=node.value))
bin_op = with_loc(
node, ast.BinOp(left=node.target, op=node.op, right=node.value)
)
assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op))
return self.visit_Assign(assign)

Expand All @@ -103,7 +108,9 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt:
from guppy.checker.func_checker import check_nested_func_def

func_def = check_nested_func_def(node, self.bb, self.ctx)
self.ctx.locals[func_def.name] = Variable(func_def.name, func_def.ty, func_def, None)
self.ctx.locals[func_def.name] = Variable(
func_def.name, func_def.ty, func_def, None
)
return func_def

def visit_If(self, node: ast.If) -> None:
Expand Down
Loading

0 comments on commit 8616f6b

Please sign in to comment.