Skip to content

Commit

Permalink
feat: Array indexing (#415)
Browse files Browse the repository at this point in the history
This is the feature branch for array indexing. Tracked by #253.

The goal is to allow array indexing in `@inout` positions that ignores
linearity constraints in the indices:

```python
qs: array[qubit, 42]
cx(qs[0], qs[1])  # Ok
cx(qs[0], qs[0])  # Compiles, but will panic at runtime
q = qs[0]  # Error: Indexing only allowed in inout position
```

This is achieved by lowering `array[qubit]` to `array[qubit | None]` and
making `array.__getitem__` and `array.__setitem__` swap in `None`
whenever a qubit is taken out or put back in. The functions
`array.__getitem__` and `array.__setitem__` take the array as `@inout`,
so it suffices to apply the following desugaring:

```python
cx(qs[expr1], qs[expr2])

# Desugars to
idx1 = expr1
idx2 = expr2
tmp1 = qs.__getitem__(idx1)
tmp2 = qs.__getitem__(idx2)
cx(tmp1, tmp2)
qs.__setitem__(idx1, tmp1)
qs.__setitem__(idx2, tmp2)
```

This also means that arrays containing structs might behave slightly
unexpectedly:

```python
@guppy.struct
class QubitPair:
    q1: qubit
    q2: qubit

ps: array[QubitPair, 42]
cx(ps[0].q1, ps[0].q2)  # Panics at runtime :(

# Since it desugars to
idx1 = 0
idx2 = 0
tmp1 = ps.__getitem__(idx1)
tmp2 = ps.__getitem__(idx2)  # Panic: Struct at index 0 has already been replaced with None
cx(tmp1.q1, tmp2.q2)
...
```

To solve this, we need to change the Hugr lowering of arrays to
structurally replace any occurrence of `qubit` with `qubit | None`, i.e.
instead of doing `array[QubitPair | None]`, we would need to do
`array[tuple[qubit | None, qubit | None]]`. I'll leave this to a future
PR in case we are interested in that.

The following feature PRs target this branch:
* #420 
* #421
* #422
* #447
  • Loading branch information
mark-koch authored Sep 4, 2024
1 parent 553ec51 commit 2199b48
Show file tree
Hide file tree
Showing 26 changed files with 747 additions and 43 deletions.
43 changes: 41 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
#:
#: All places are equipped with a unique id, a type and an optional definition AST
#: location. During linearity checking, they are tracked separately.
Place: TypeAlias = "Variable | FieldAccess"
Place: TypeAlias = "Variable | FieldAccess | SubscriptAccess"

#: Unique identifier for a `Place`.
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id"
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id | SubscriptAccess.Id"


@dataclass(frozen=True)
Expand Down Expand Up @@ -154,6 +154,45 @@ def replace_defined_at(self, node: AstNode | None) -> "FieldAccess":
return replace(self, exact_defined_at=node)


@dataclass(frozen=True)
class SubscriptAccess:
    """A place identifying a subscript `place[item]` access.

    The index is stored twice: as the expression that computes it (`item_expr`) and
    as a variable (`item`) through which the index value can be referred to again,
    e.g. when a matching `__setitem__` call is checked later on.
    """

    #: The place that is being subscripted
    parent: Place
    #: Variable that names the index value
    item: Variable
    #: Type of the value obtained by the subscript
    ty: Type
    #: Expression that computes the index
    item_expr: ast.expr
    #: Checked call to `__getitem__` on the parent with `item` as the index
    getitem_call: ast.expr
    #: Checked call to `__setitem__` that puts the value back. Only populated if this
    #: place occurs in an inout position
    setitem_call: ast.expr | None = None

    @dataclass(frozen=True)
    class Id:
        """Identifier for subscript places.

        Made up of the identifier of the subscripted parent place and the identifier
        of the variable holding the index.
        """

        parent: PlaceId
        item: Variable.Id

    @cached_property
    def id(self) -> "SubscriptAccess.Id":
        """The unique `PlaceId` identifier for this place.

        Built from the ids of the parent place and of the item variable.
        """
        return SubscriptAccess.Id(self.parent.id, self.item.id)

    @cached_property
    def defined_at(self) -> AstNode | None:
        """Optional location where this place was last assigned to.

        A subscript place stores no location of its own; the parent's is reported.
        """
        return self.parent.defined_at

    @property
    def describe(self) -> str:
        """A human-readable description of this place for error messages."""
        return f"Subscript `{self}`"

    def __str__(self) -> str:
        """String representation of this place.

        The index is elided, i.e. the place is rendered as `parent[...]`.
        """
        return f"{self.parent}[...]"


PyScope = dict[str, Any]


Expand Down
82 changes: 74 additions & 8 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import sys
import traceback
from contextlib import suppress
from dataclasses import replace
from typing import Any, NoReturn, cast

from guppylang.ast_util import (
Expand All @@ -35,12 +36,15 @@
with_loc,
with_type,
)
from guppylang.cfg.builder import tmp_vars
from guppylang.checker.core import (
Context,
DummyEvalDict,
FieldAccess,
Globals,
Locals,
Place,
SubscriptAccess,
Variable,
)
from guppylang.definition.common import Definition
Expand All @@ -58,6 +62,7 @@
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
InoutReturnSentinel,
IterEnd,
IterHasNext,
IterNext,
Expand All @@ -66,6 +71,7 @@
PartialApply,
PlaceNode,
PyExpr,
SubscriptAccessAndDrop,
TensorCall,
TypeApply,
)
Expand Down Expand Up @@ -491,7 +497,7 @@ def _synthesize_binary(
node,
)

def _synthesize_instance_func(
def synthesize_instance_func(
self,
node: ast.expr,
args: list[ast.expr],
Expand Down Expand Up @@ -539,16 +545,37 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:

def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
item_expr, item_ty = self.synthesize(node.slice)
# Give the item a unique name so we can refer to it later in case we also want
# to compile a call to `__setitem__`
item = Variable(next(tmp_vars), item_ty, item_expr)
item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item)))
# Check a call to the `__getitem__` instance function
exp_sig = FunctionType(
[
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.Inout),
FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags),
],
ExistentialTypeVar.fresh("Val", False),
)
return self._synthesize_instance_func(
node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig
getitem_expr, result_ty = self.synthesize_instance_func(
node.value, [item_node], "__getitem__", "not subscriptable", exp_sig
)
# Subscripting a place is itself a place
expr: ast.expr
if isinstance(node.value, PlaceNode):
place = SubscriptAccess(
node.value.place, item, result_ty, item_expr, getitem_expr
)
expr = PlaceNode(place=place)
else:
# If the subscript is not on a place, then there is no way to address the
# other indices after this one has been projected out (e.g. `f()[0]` makes
# you lose access to all elements besides 0).
expr = SubscriptAccessAndDrop(
item=item, item_expr=item_expr, getitem_expr=getitem_expr
)
return with_loc(node, expr), result_ty

def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
if len(node.keywords) > 0:
Expand Down Expand Up @@ -600,7 +627,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self._synthesize_instance_func(
expr, ty = self.synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
)

Expand All @@ -624,7 +651,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

Expand All @@ -634,14 +661,14 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
[FuncInput(ty, InputFlags.NoFlags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__next__", "not an iterator", exp_sig, True
)

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)

Expand Down Expand Up @@ -764,6 +791,8 @@ def type_check_args(
new_args: list[ast.expr] = []
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
a.place = check_inout_arg_place(a.place, ctx, a)
new_args.append(a)
subst |= s

Expand All @@ -784,6 +813,43 @@ def type_check_args(
return new_args, subst


def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
"""Performs additional checks for place arguments in @inout position.
In particular, we need to check that places involving `place[item]` subscripts
implement the corresponding `__setitem__` method.
"""
match place:
case Variable():
return place
case FieldAccess(parent=parent):
return replace(place, parent=check_inout_arg_place(parent, ctx, node))
case SubscriptAccess(parent=parent, item=item, ty=ty):
# Check a call to the `__setitem__` instance function
exp_sig = FunctionType(
[
FuncInput(parent.ty, InputFlags.Inout),
FuncInput(item.ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.NoFlags),
],
NoneType(),
)
setitem_args = [
with_type(parent.ty, with_loc(node, PlaceNode(parent))),
with_type(item.ty, with_loc(node, PlaceNode(item))),
with_type(ty, with_loc(node, InoutReturnSentinel(var=place))),
]
setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
setitem_args[0],
setitem_args[1:],
"__setitem__",
"not allowed in a subscripted `@inout` position",
exp_sig,
True,
)
return replace(place, setitem_call=setitem_call)


def synthesize_call(
func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], Type, Inst]:
Expand Down
77 changes: 66 additions & 11 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Locals,
Place,
PlaceId,
SubscriptAccess,
Variable,
)
from guppylang.definition.custom import CustomFunctionDef
Expand All @@ -33,9 +34,15 @@
LocalCall,
PartialApply,
PlaceNode,
SubscriptAccessAndDrop,
TensorCall,
)
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
StructType,
)


class Scope(Locals[PlaceId, Place]):
Expand Down Expand Up @@ -137,16 +144,31 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non
"ownership of the value.",
node,
)
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
# Places involving subscripts are handled differently since we ignore everything
# after the subscript for the purposes of linearity checking
if subscript := contains_subscript(node.place):
if not is_inout_arg and subscript.parent.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` was already "
"used (at {0})",
"Subscripting on expression with linear type "
f"`{subscript.parent.ty}` is only allowed in `@inout` position",
node,
[use],
)
self.scope.use(x, node)
self.scope.assign(subscript.item)
# Visiting the `__getitem__(place.parent, place.item)` call ensures that we
# linearity-check the parent and element.
self.visit(subscript.getitem_call)
# For all other places, we record uses of all leafs
else:
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` was already "
"used (at {0})",
node,
[use],
)
self.scope.use(x, node)

def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
Expand All @@ -170,9 +192,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
for leaf in leaf_places(place):
leaf = leaf.replace_defined_at(arg)
self.scope.assign(leaf)
self._reassign_single_inout_arg(place, arg)
case arg if inp.ty.linear:
raise GuppyError(
f"Inout argument with linear type `{inp.ty}` would be "
Expand All @@ -182,6 +202,19 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
arg,
)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
    """Hands a single inout argument back to the scope after a function call.

    Places without a subscript are handed back by assigning each of their leaf
    places, with `node` recorded as the new definition location. Places that
    involve a subscript are handed back by visiting the recorded `__setitem__`
    call, after which the subscripted parent is handed back in the same way.
    """
    subscript = contains_subscript(place)
    if not subscript:
        for leaf in leaf_places(place):
            assert not isinstance(leaf, SubscriptAccess)
            self.scope.assign(leaf.replace_defined_at(node))
        return
    # The `__setitem__` call must have been filled in when the argument was checked
    assert subscript.setitem_call is not None
    self.visit(subscript.setitem_call)
    self._reassign_single_inout_arg(subscript.parent, node)

def visit_GlobalCall(self, node: GlobalCall) -> None:
func = self.globals[node.def_id]
assert isinstance(func, CallableDef)
Expand Down Expand Up @@ -233,6 +266,19 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None:
node.value,
)

def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None:
    """Linearity-checks a subscript access on a value that is not a place.

    Since the subscripted value is not a place, it can no longer be accessed once
    the item has been projected out. This is rejected with a `GuppyTypeError` if
    the type of the `__getitem__` expression is linear.
    """
    item_ty = get_type(node.getitem_expr)
    if item_ty.linear:
        msg = f"Remaining linear items with type `{item_ty}` are not used"
        raise GuppyTypeError(msg, node)
    # Check the index expression first, then bind the variable that names the index
    # before checking the `__getitem__` call
    self.visit(node.item_expr)
    self.scope.assign(node.item)
    self.visit(node.getitem_expr)

def visit_Expr(self, node: ast.Expr) -> None:
# An expression statement where the return value is discarded
self.visit(node.value)
Expand Down Expand Up @@ -376,6 +422,15 @@ def leaf_places(place: Place) -> Iterator[Place]:
yield place


def contains_subscript(place: Place) -> SubscriptAccess | None:
    """Finds the rightmost subscript access contained in a place, if there is one.

    The chain of parents is followed from the given place until a variable is
    reached. The first subscript access met on the way is returned; `None` is
    returned if a variable is reached without meeting one.
    """
    current = place
    while True:
        if isinstance(current, Variable):
            return None
        if isinstance(current, SubscriptAccess):
            return current
        current = current.parent


def is_inout_var(place: Place) -> TypeGuard[Variable]:
    """Checks whether a place is a variable that carries the `Inout` flag."""
    if not isinstance(place, Variable):
        return False
    return InputFlags.Inout in place.flags
Expand Down
3 changes: 3 additions & 0 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __setitem__(self, place: Place, port: Wire) -> None:
else:
self.locals[place.id] = port

def __contains__(self, place: Place) -> bool:
    """Checks whether `self.locals` has an entry under the id of the given place."""
    return place.id in self.locals

def __copy__(self) -> "DFContainer":
# Make a copy of the var map so that mutating the copy doesn't
# mutate our variable mapping
Expand Down
Loading

0 comments on commit 2199b48

Please sign in to comment.