Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make True and False branches unconditional #740

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions guppylang/cfg/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,21 @@ class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):

stats: dict[BB, VariableStats[VId]]

def __init__(self, stats: dict[BB, VariableStats[VId]]) -> None:
def __init__(
self,
stats: dict[BB, VariableStats[VId]],
initial: LivenessDomain[VId] | None = None,
) -> None:
self.stats = stats
self._initial = initial or {}

def eq(self, live1: LivenessDomain[VId], live2: LivenessDomain[VId]) -> bool:
# Only check that both contain the same variables. We don't care about the BB
# in which the use occurs, we just need any one, to report to the user.
return live1.keys() == live2.keys()

def initial(self) -> LivenessDomain[VId]:
return {}
return self._initial

def join(self, *ts: LivenessDomain[VId]) -> LivenessDomain[VId]:
res: LivenessDomain[VId] = {}
Expand Down Expand Up @@ -183,6 +188,9 @@ def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
def apply_bb(
self, val_before: AssignmentDomain[VId], bb: BB
) -> AssignmentDomain[VId]:
# For unreachable BBs, we can assume that everything is assigned
if not bb.predecessors and bb != bb.containing_cfg.entry_bb:
return self.all_vars, self.all_vars
stats = self.stats[bb]
def_ass_before, maybe_ass_before = val_before
return (
Expand Down
34 changes: 32 additions & 2 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,34 @@ def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) ->
nodes, self.cfg.entry_bb, Jumps(self.cfg.exit_bb, None, None)
)

# Compute reachable BBs
reachable = self.cfg.reachable_from(self.cfg.entry_bb)

# If we're still in a basic block after compiling the whole body, we have to add
# an implicit void return
if final_bb is not None:
if final_bb is not None and final_bb in reachable:
if not returns_none:
raise GuppyError(ExpectedError(nodes[-1], "return statement"))
self.cfg.link(final_bb, self.cfg.exit_bb)

reachable.add(self.cfg.exit_bb)

# Complain about unreachable code
unreachable = set(self.cfg.bbs) - reachable
for bb in unreachable:
if bb.statements:
raise GuppyError(UnreachableError(bb.statements[0]))
if bb.branch_pred:
raise GuppyError(UnreachableError(bb.branch_pred))
# Empty unreachable BBs are fine, we just prune them
if bb.successors:
# Since there is no branch expression, there can be at most a single
# successor
[succ] = bb.successors
succ.predecessors.remove(bb)

# If we made it till here, there are only the reachable BBs left. The only
# exception is the exit BB which should never be dropped, even if unreachable.
self.cfg.bbs = list(reachable | {self.cfg.exit_bb})
return self.cfg

def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
Expand Down Expand Up @@ -368,6 +389,15 @@ def add_branch(node: ast.expr, cfg: CFG, bb: BB, true_bb: BB, false_bb: BB) -> N
builder = BranchBuilder(cfg)
builder.visit(node, bb, true_bb, false_bb)

def visit_Constant(
self, node: ast.Constant, bb: BB, true_bb: BB, false_bb: BB
) -> None:
# Branching on `True` or `False` constant should be unconditional
if isinstance(node.value, bool):
self.cfg.link(bb, true_bb if node.value else false_bb)
else:
self.generic_visit(node, bb, true_bb, false_bb)

def visit_BoolOp(self, node: ast.BoolOp, bb: BB, true_bb: BB, false_bb: BB) -> None:
# Add short-circuit evaluation of boolean expression. If there are more than 2
# operators, we turn the flat operator list into a right-nested tree to allow
Expand Down
19 changes: 18 additions & 1 deletion guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ def ancestors(self, *bbs: T) -> Iterator[T]:
yield bb
queue += bb.predecessors

def reachable_from(self, bb: T) -> set[T]:
"""Returns the set of all BBs reachable from some given BB."""
queue = {bb}
reachable = set()
while queue:
bb = queue.pop()
if bb not in reachable:
reachable.add(bb)
for succ in bb.successors:
queue.add(succ)
return reachable


class CFG(BaseCFG[BB]):
"""A control-flow graph of unchecked basic blocks."""
Expand Down Expand Up @@ -84,7 +96,12 @@ def analyze(
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
# Mark all borrowed variables as implicitly used in the exit BB
stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
# This also means borrowed variables are always live, so we can use them as the
# initial value in the liveness analysis. This solves the edge case that
# borrowed variables should be considered live, even if the exit is actually
# unreachable (to avoid linearity violations later).
inout_live = {x: self.exit_bb for x in inout_vars}
self.live_before = LivenessAnalysis(stats, initial=inout_live).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
stats, def_ass_before, maybe_ass_before
).run_unpacked(self.bbs)
Expand Down
15 changes: 13 additions & 2 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def check_cfg(
"""
# First, we need to run program analysis
ass_before = {v.name for v in inputs}
inout_vars = [v.name for v in inputs if InputFlags.Inout in v.flags]
cfg.analyze(ass_before, ass_before, inout_vars)
inout_vars = [v for v in inputs if InputFlags.Inout in v.flags]
cfg.analyze(ass_before, ass_before, [v.name for v in inout_vars])

# We start by compiling the entry BB
checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty)
Expand Down Expand Up @@ -123,6 +123,17 @@ def check_cfg(
compiled[bb].predecessors.append(pred)
pred.successors[num_output] = compiled[bb]

# The exit BB might be unreachable. In that case it won't be visited above and we
# have to handle it here
if cfg.exit_bb not in compiled:
assert len(cfg.exit_bb.predecessors) == 0
assert len(cfg.exit_bb.successors) == 0
assert len(cfg.exit_bb.statements) == 0
assert cfg.exit_bb.branch_pred is None
compiled[cfg.exit_bb] = CheckedBB(
cfg.exit_bb.idx, checked_cfg, sig=Signature(inout_vars, [])
)

checked_cfg.bbs = list(compiled.values())
checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable
checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs}
Expand Down
25 changes: 22 additions & 3 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, NamedTuple, TypeGuard

from guppylang.ast_util import AstNode, find_nodes, get_type
from guppylang.cfg.analysis import LivenessAnalysis
from guppylang.cfg.analysis import LivenessAnalysis, LivenessDomain
from guppylang.cfg.bb import BB, VariableStats
from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Row, Signature
from guppylang.checker.core import (
Expand Down Expand Up @@ -610,9 +610,28 @@ def check_cfg_linearity(
for leaf in leaf_places(var):
exit_scope.use(leaf.id, InoutReturnSentinel(var=var), UseKind.RETURN)

# Run liveness analysis
# Edge case: If the exit is unreachable, then the function will never terminate, so
# there is no need to give the borrowed values back to the caller. To ensure that
# the generated Hugr is still valid, we have to thread the borrowed arguments
# through the non-terminating loop. We achieve this by considering borrowed
# variables as live in every BB, even if the actual use in the exit is unreachable.
# This is done by including borrowed vars in the initial value for the liveness
# analysis below. The analogous thing was also done in the previous `CFG.analyze`
# pass.
live_default: LivenessDomain[PlaceId] = (
{
leaf.id: cfg.exit_bb
for var in cfg.entry_bb.sig.input_row
if InputFlags.Inout in var.flags
for leaf in leaf_places(var)
}
if not cfg.exit_bb.predecessors
else {}
)

# Run liveness analysis with this initial value
stats = {bb: scope.stats() for bb, scope in scopes.items()}
live_before = LivenessAnalysis(stats).run(cfg.bbs)
live_before = LivenessAnalysis(stats, initial=live_default).run(cfg.bbs)

# Construct a CFG that tracks places instead of just variables
result_cfg: CheckedCFG[Place] = CheckedCFG(cfg.input_tys, cfg.output_ty)
Expand Down
11 changes: 11 additions & 0 deletions guppylang/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def compile_cfg(

builder = container.add_cfg(*inputs)

# Explicitly annotate the output types since Hugr can't infer them if the exit is
# unreachable
out_tys = [place.ty.to_hugr() for place in cfg.exit_bb.sig.input_row]
# TODO: Use proper API for this once it's added in hugr-py:
# https://github.com/CQCL/hugr/issues/1816
builder._exit_op._cfg_outputs = out_tys
builder.parent_op._outputs = out_tys
builder.parent_node = builder.hugr._update_node_outs(
builder.parent_node, len(out_tys)
)

blocks: dict[CheckedBB[Place], ToNode] = {}
for bb in cfg.bbs:
blocks[bb] = compile_bb(bb, builder, bb == cfg.entry_bb, globals)
Expand Down
10 changes: 10 additions & 0 deletions tests/error/linear_errors/unused_non_terminating.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error: Linearity violation (at $FILE:13:8)
|
11 |
12 | @guppy(module)
13 | def foo(q: qubit @owned) -> None:
| ^^^^^^^^^^^^^^^ Variable `q` with linear type `qubit` is leaked
Comment on lines +1 to +6
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably, this shouldn't be an error since non-terminating functions technically cannot leak. However, it would be difficult to generate valid Hugr for cases like this, e.g.

while True:
    qubit()

Therefore, I think it's ok to still complain about leaks, even if the function provably doesn't terminate


Help: Make sure that `q` is consumed or returned to avoid the leak

Guppy compilation failed due to 1 previous error
18 changes: 18 additions & 0 deletions tests/error/linear_errors/unused_non_terminating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import guppylang.std.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import owned
from guppylang.std.quantum import qubit


module = GuppyModule("test")
module.load_all(quantum)


@guppy(module)
def foo(q: qubit @owned) -> None:
while True:
pass


module.compile()
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_10.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Unreachable (at $FILE:8:4)
|
6 | while True:
7 | x += 1
8 | return x
| ^^^^^^^^ This code is not reachable

Guppy compilation failed due to 1 previous error
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> int:
while True:
x += 1
return x
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_11.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Unreachable (at $FILE:9:8)
|
7 | if True:
8 | break
9 | x += 1
| ^^^^^^ This code is not reachable

Guppy compilation failed due to 1 previous error
10 changes: 10 additions & 0 deletions tests/error/misc_errors/unreachable_11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> int:
while not False:
if True:
break
x += 1
return x
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_6.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Unreachable (at $FILE:7:8)
|
5 | def foo(x: int) -> int:
6 | if False:
7 | x += 1
| ^^^^^^ This code is not reachable

Guppy compilation failed due to 1 previous error
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> int:
if False:
x += 1
return x
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_7.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Unreachable (at $FILE:8:4)
|
6 | if True:
7 | return 1
8 | return 0
| ^^^^^^^^ This code is not reachable

Guppy compilation failed due to 1 previous error
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from tests.util import compile_guppy


@compile_guppy
def foo() -> int:
if True:
return 1
return 0
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_8.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Unreachable (at $FILE:9:8)
|
7 | x += 1
8 | else:
9 | x -= 1
| ^^^^^^ This code is not reachable

Guppy compilation failed due to 1 previous error
10 changes: 10 additions & 0 deletions tests/error/misc_errors/unreachable_8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> int:
if not False:
x += 1
else:
x -= 1
return x
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_9.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Unreachable (at $FILE:6:18)
|
4 | @compile_guppy
5 | def foo(x: int) -> int:
6 | if False and (x := x + 1):
| ^^^^^^^^^^ This code is not reachable

Guppy compilation failed due to 1 previous error
8 changes: 8 additions & 0 deletions tests/error/misc_errors/unreachable_9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> int:
if False and (x := x + 1):
pass
return x
Loading
Loading