diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py index b02e00ad..ba38ded8 100644 --- a/guppy/cfg/cfg.py +++ b/guppy/cfg/cfg.py @@ -19,7 +19,7 @@ NestedFunctionDef, BBStatement, ) -from guppy.compiler_base import VarMap, DFContainer, Variable, Globals +from guppy.compiler_base import VarMap, DFContainer, Variable, Globals, is_return_var from guppy.error import InternalGuppyError, GuppyError, GuppyTypeError from guppy.ast_util import AstVisitor, line_col, set_location_from from guppy.expression import ExpressionCompiler @@ -234,7 +234,8 @@ def _compile_bb( # We put all non-linear variables into the branch predicate and all # linear variables in the normal output (since they are shared between # all successors). This is in line with the definition of `<` on - # variables which puts linear variables at the end. + # variables which puts linear variables at the end. The only exception + # are return vars which must be outputted in order. branch_port = self._choose_vars_for_pred( graph=graph, pred=branch_port, @@ -242,7 +243,7 @@ def _compile_bb( sorted( x for x in self.live_before[succ] - if x in dfg and not dfg[x].ty.linear + if x in dfg and (not dfg[x].ty.linear or is_return_var(x)) ) for succ in bb.successors ], @@ -253,7 +254,7 @@ def _compile_bb( # We can look at `successors[0]` here since all successors must have # the same `live_before` linear variables for x in self.live_before[bb.successors[0]] - if x in dfg and dfg[x].ty.linear + if x in dfg and dfg[x].ty.linear and not is_return_var(x) ) graph.add_output( diff --git a/guppy/compiler_base.py b/guppy/compiler_base.py index 94613880..ecd4ff27 100644 --- a/guppy/compiler_base.py +++ b/guppy/compiler_base.py @@ -27,9 +27,12 @@ def __lt__(self, other: Any) -> bool: # We define an ordering on variables that is used to determine in which order # they are outputted from basic blocks. We need to output linear variables at # the end, so we do a lexicographic ordering of linearity and name, exploiting - # the fact that `False < True` in Python. + # the fact that `False < True` in Python. The only exception are return vars + # which must be outputted in order. if not isinstance(other, Variable): return NotImplemented + if is_return_var(self.name) and is_return_var(other.name): + return self.name < other.name return (self.ty.linear, self.name) < (other.ty.linear, other.name) @@ -241,3 +244,8 @@ def return_var(n: int) -> str: e1 ; %ret2 = e2`. This way, we can reuse our existing mechanism for passing of live variables between basic blocks.""" return f"%ret{n}" + + +def is_return_var(x: str) -> bool: + """Checks whether the given name is a dummy return variable.""" + return x.startswith("%ret") diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index decc4f5c..34e910bd 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -29,6 +29,18 @@ def test(q: Qubit) -> Qubit: validate(module.compile(True)) +def test_linear_return_order(validate): + # See https://github.com/CQCL-DEV/guppy/issues/35 + module = GuppyModule("test") + module.load(quantum) + + @module + def test(q: Qubit) -> tuple[Qubit, bool]: + return measure(q) + + validate(module.compile(True)) + + def test_interleave(validate): module = GuppyModule("test") module.load(quantum)