Skip to content

Commit

Permalink
refactor: Turn linearity checking into separate compiler stage (#273)
Browse files Browse the repository at this point in the history
Closes #272
  • Loading branch information
mark-koch authored Jun 27, 2024
1 parent 792fb87 commit 99320b1
Show file tree
Hide file tree
Showing 16 changed files with 277 additions and 123 deletions.
5 changes: 4 additions & 1 deletion guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import itertools
from collections.abc import Iterator
from typing import NamedTuple
Expand Down Expand Up @@ -270,7 +271,9 @@ def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.Name:
# assignment statement and replace the expression with `x`.
if not isinstance(node.target, ast.Name):
raise InternalGuppyError(f"Unexpected assign target: {node.target}")
assign = ast.Assign(targets=[node.target], value=self.visit(node.value))
assign = ast.Assign(
targets=[copy.deepcopy(node.target)], value=self.visit(node.value)
)
set_location_from(assign, node)
self.bb.statements.append(assign)
return node.target
Expand Down
30 changes: 6 additions & 24 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from guppylang.cfg.cfg import CFG, BaseCFG
from guppylang.checker.core import Context, Globals, Locals, Variable
from guppylang.checker.expr_checker import ExprSynthesizer, to_bool
from guppylang.checker.linearity_checker import check_cfg_linearity
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.error import GuppyError
from guppylang.tys.ty import Type
Expand Down Expand Up @@ -84,7 +85,7 @@ def check_cfg(
while len(queue) > 0:
pred, num_output, bb = queue.popleft()
input_row = [
Variable(v.name, v.ty, v.defined_at, None)
Variable(v.name, v.ty, v.defined_at)
for v in pred.sig.output_rows[num_output]
]

Expand Down Expand Up @@ -114,6 +115,10 @@ def check_cfg(
checked_cfg.maybe_ass_before = {
compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
}

# Finally, run the linearity check
check_cfg_linearity(checked_cfg)

return checked_cfg


Expand Down Expand Up @@ -160,29 +165,6 @@ def check_bb(
)
raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x])

# We have to check that used linear variables are not being outputted
if x in ctx.locals:
var = ctx.locals[x]
if var.ty.linear and var.used:
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` was "
"already used (at {0})",
cfg.live_before[succ][x].vars.used[x],
[var.used],
)

# On the other hand, unused linear variables *must* be outputted
for x, var in ctx.locals.items():
if var.ty.linear and not var.used and x not in cfg.live_before[succ]:
# TODO: This should be "Variable x with linear type ty is not
# used in {bb}". But for this we need a way to associate BBs with
# source locations.
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` is "
"not used on all control-flow paths",
var.defined_at,
)

# Finally, we need to compute the signature of the basic block
outputs = [
[ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals]
Expand Down
3 changes: 1 addition & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@
)


@dataclass
@dataclass(frozen=True)
class Variable:
"""Class holding data associated with a local variable."""

name: str
ty: Type
defined_at: AstNode | None
used: AstNode | None


PyScope = dict[str, Any]
Expand Down
65 changes: 4 additions & 61 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
AstVisitor,
breaks_in_loop,
get_type_opt,
name_nodes_in_ast,
return_nodes_in_ast,
with_loc,
with_type,
Expand Down Expand Up @@ -333,14 +332,6 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]:
x = node.id
if x in self.ctx.locals:
var = self.ctx.locals[x]
if var.ty.linear and var.used is not None:
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` was "
"already used (at {0})",
node,
[var.used],
)
var.used = node
return with_loc(node, LocalName(id=x)), var.ty
elif x in self.ctx.globals:
# Look-up what kind of definition it is
Expand Down Expand Up @@ -874,34 +865,14 @@ def synthesize_comprehension(
"""Helper function to synthesise the element type of a list comprehension."""
from guppylang.checker.stmt_checker import StmtChecker

def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None:
"""Checks if an expression uses a linear variable from an outer scope.
Since the expression is executed multiple times in the inner scope, this would
mean that the outer linear variable is used multiple times, which is not
allowed.
"""
for name in name_nodes_in_ast(expr):
x = name.id
if x in locals and x not in locals.vars:
var = locals[x]
if var.ty.linear:
raise GuppyTypeError(
f"Variable `{x}` with linear type `{var.ty}` would be used "
"multiple times when evaluating this comprehension",
name,
)

# If there are no more generators left, we can check the list element
if not gens:
node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt)
check_linear_use_from_outer_scope(node.elt, ctx.locals)
return node, elt_ty

# Check the iterator in the outer context
gen, *gens = gens
gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign)
check_linear_use_from_outer_scope(gen.iter_assign.value, ctx.locals)

# The rest is checked in a new nested context to ensure that variables don't escape
# their scope
Expand All @@ -915,43 +886,15 @@ def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None:
gen.iter, iter_ty = expr_sth.visit_Name(gen.iter)
gen.iter = with_type(iter_ty, gen.iter)

# `if` guards are generally not allowed when we're iterating over linear variables.
# The only exception is if all linear variables are already consumed by the first
# guard
if gen.ifs:
gen.ifs[0], _ = expr_sth.synthesize(gen.ifs[0])

# Now, check if there are linear iteration variables that have not been used by
# the first guard
for target in name_nodes_in_ast(gen.next_assign.targets[0]):
var = inner_ctx.locals[target.id]
if var.ty.linear and not var.used and gen.ifs:
raise GuppyTypeError(
f"Variable `{var.name}` with linear type `{var.ty}` is not used on "
"all control-flow paths of the list comprehension",
target,
)

# Now, we can properly check all guards
for i in range(len(gen.ifs)):
gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i])
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)
check_linear_use_from_outer_scope(gen.ifs[i], inner_locals)
# Check `if` guards
for i in range(len(gen.ifs)):
gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i])
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)

# Check remaining generators
node, elt_ty = synthesize_comprehension(node, gens, inner_ctx)

# We have to make sure that all linear variables that were introduced in this scope
# have been used
for x, var in inner_ctx.locals.vars.items():
if var.ty.linear and not var.used:
raise GuppyTypeError(
f"Variable `{x}` with linear type `{var.ty}` is not used",
var.defined_at,
)

# The iter finalizer is again checked in the outer context
ctx.locals[gen.iter.id].used = None
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return node, elt_ty
Expand Down
6 changes: 3 additions & 3 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def check_global_func_def(

cfg = CFGBuilder().build(func_def.body, returns_none, globals)
inputs = [
Variable(x, ty, loc, None)
Variable(x, ty, loc)
for x, ty, loc in zip(ty.input_names, ty.inputs, args, strict=True)
]
return check_cfg(cfg, inputs, ty.output, globals)
Expand Down Expand Up @@ -87,7 +87,7 @@ def check_nested_func_def(

# Construct inputs for checking the body CFG
inputs = list(captured.values()) + [
Variable(x, ty, func_def.args.args[i], None)
Variable(x, ty, func_def.args.args[i])
for i, (x, ty) in enumerate(
zip(func_ty.input_names, func_ty.inputs, strict=True)
)
Expand All @@ -111,7 +111,7 @@ def check_nested_func_def(
)
else:
# Otherwise, we treat it like a local name
inputs.append(Variable(func_def.name, func_def.ty, func_def, None))
inputs.append(Variable(func_def.name, func_def.ty, func_def))

checked_cfg = check_cfg(cfg, inputs, func_ty.output, globals)
checked_def = CheckedNestedFunctionDef(
Expand Down
Loading

0 comments on commit 99320b1

Please sign in to comment.