Skip to content

Commit

Permalink
fix: Fix array lowering bugs (#575)
Browse files Browse the repository at this point in the history
Fixes #573 together with other list lowering issues.

There were some inconsistencies whether `None` is `Tag(0)` or `Tag(1)`.
Now we consistently use `ht.Option` and the `build_unwrap` helper.
  • Loading branch information
mark-koch authored Oct 18, 2024
1 parent 117b68e commit 83b9f31
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 185 deletions.
294 changes: 128 additions & 166 deletions guppylang/prelude/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import hugr.std
from hugr import Wire, ops
from hugr import tys as ht
Expand All @@ -12,229 +10,193 @@
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
from guppylang.prelude._internal.compiler.arithmetic import convert_itousize
from guppylang.prelude._internal.compiler.prelude import build_error, build_panic
from guppylang.prelude._internal.compiler.prelude import (
build_unwrap,
build_unwrap_left,
build_unwrap_right,
)
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.const import ConstValue

if TYPE_CHECKING:
from hugr.build.dfg import DfBase
# ------------------------------------------------------
# --------------- std.array operations -----------------
# ------------------------------------------------------


def _instantiate_array_op(
name: str, elem_ty: ht.Type, length: int, inp: list[ht.Type], out: list[ht.Type]
) -> ops.ExtOp:
return hugr.std.PRELUDE.get_op(name).instantiate(
[ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)], ht.FunctionType(inp, out)
)


def array_type(length: int, elem_ty: ht.Type) -> ht.ExtType:
def array_type(elem_ty: ht.Type, length: int) -> ht.ExtType:
"""Returns the hugr type of a fixed length array."""
length_arg = ht.BoundedNatArg(length)
elem_arg = ht.TypeTypeArg(elem_ty)
return hugr.std.PRELUDE.types["array"].instantiate([length_arg, elem_arg])


def array_new(elem_ty: ht.Type, length: int) -> ops.ExtOp:
"""Returns an operation that creates a new fixed length array."""
arr_ty = array_type(elem_ty, length)
return _instantiate_array_op(
"new_array", elem_ty, length, [elem_ty] * length, [arr_ty]
)


def array_get(elem_ty: ht.Type, length: int) -> ops.ExtOp:
"""Returns an array `get` operation."""
assert elem_ty.type_bound() == ht.TypeBound.Copyable
arr_ty = array_type(elem_ty, length)
return _instantiate_array_op(
"get", elem_ty, length, [arr_ty, ht.USize()], [ht.Option(elem_ty)]
)


def array_set(elem_ty: ht.Type, length: int) -> ops.ExtOp:
"""Returns an array `set` operation."""
arr_ty = array_type(elem_ty, length)
return _instantiate_array_op(
"set",
elem_ty,
length,
[arr_ty, ht.USize(), elem_ty],
[ht.Either([elem_ty, arr_ty], [elem_ty, arr_ty])],
)


# ------------------------------------------------------
# --------- Custom compilers for non-native ops --------
# ------------------------------------------------------


class NewArrayCompiler(CustomCallCompiler):
class ArrayCompiler(CustomCallCompiler):
"""Base class for custom array op compilers."""

@property
def elem_ty(self) -> ht.Type:
"""The element type for the array op that is being compiled."""
match self.type_args:
case [TypeArg(ty=elem_ty), _]:
return elem_ty.to_hugr()
case _:
raise InternalGuppyError("Invalid array type args")

@property
def length(self) -> int:
"""The length for the array op that is being compiled."""
match self.type_args:
case [_, ConstArg(ConstValue(value=int(length)))]:
return length
case _:
raise InternalGuppyError("Invalid array type args")


class NewArrayCompiler(ArrayCompiler):
"""Compiler for the `array.__new__` function."""

def build_classical_array(self, elems: list[Wire]) -> Wire:
"""Lowers a call to `array.__new__` for classical arrays."""
return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems)

def build_linear_array(self, elems: list[Wire]) -> Wire:
"""Lowers a call to `array.__new__` for linear arrays."""
elem_opt_ty = ht.Option(self.elem_ty)
elem_opts = [
self.builder.add_op(ops.Tag(1, elem_opt_ty), elem) for elem in elems
]
return self.builder.add_op(array_new(elem_opt_ty, len(elems)), *elem_opts)

def compile(self, args: list[Wire]) -> list[Wire]:
match self.type_args:
case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]:
op = new_array(length, elem_ty.to_hugr())
return [self.builder.add_op(op, *args)]
case type_args:
raise InternalGuppyError(f"Invalid array type args: {type_args}")
if self.elem_ty.type_bound() == ht.TypeBound.Any:
return [self.build_linear_array(args)]
else:
return [self.build_classical_array(args)]


class ArrayGetitemCompiler(CustomCallCompiler):
class ArrayGetitemCompiler(ArrayCompiler):
"""Compiler for the `array.__getitem__` function."""

def build_classical_getitem(
self,
array: Wire,
array_ty: ht.Type,
idx: Wire,
idx_ty: ht.Type,
elem_ty: ht.Type,
) -> CallReturnWires:
def build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
"""Lowers a call to `array.__getitem__` for classical arrays."""
length = self.type_args[1].to_hugr()
elem = build_array_get(
self.builder, array, array_ty, idx, idx_ty, elem_ty, length
)
idx = self.builder.add_op(convert_itousize(), idx)
result = self.builder.add_op(array_get(self.elem_ty, self.length), array, idx)
elem = build_unwrap(self.builder, result, "Array index out of bounds")
return CallReturnWires(regular_returns=[elem], inout_returns=[array])

def build_linear_getitem(
self,
array: Wire,
array_ty: ht.Type,
idx: Wire,
idx_ty: ht.Type,
elem_ty: ht.Type,
) -> CallReturnWires:
def build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
"""Lowers a call to `array.__getitem__` for linear arrays."""
# Swap out the element at the given index with `None`. The `to_hugr`
# implementation of the array type ensures that linear element types are turned
# into optionals.
elem_opt_ty = ht.Sum([[elem_ty], []])
none = self.builder.add_op(ops.Tag(1, elem_opt_ty))
length = self.type_args[1].to_hugr()
array, elem_opt = build_array_set(
self.builder,
array,
array_ty,
idx,
idx_ty,
none,
elem_opt_ty,
length,
elem_opt_ty = ht.Option(self.elem_ty)
none = self.builder.add_op(ops.Tag(0, elem_opt_ty))
idx = self.builder.add_op(convert_itousize(), idx)
result = self.builder.add_op(
array_set(elem_opt_ty, self.length), array, idx, none
)
elem_opt, array = build_unwrap_right(
self.builder, result, "Array index out of bounds"
)
elem = build_unwrap(
self.builder, elem_opt, "Linear array element has already been used"
)
# Make sure that the element we got out is not None
conditional = self.builder.add_conditional(elem_opt)
with conditional.add_case(0) as case:
case.set_outputs(*case.inputs())
with conditional.add_case(1) as case:
error = build_error(case, 1, "Linear array element has already been used")
case.set_outputs(*build_panic(case, [], [elem_ty], error))
return CallReturnWires(regular_returns=[conditional], inout_returns=[array])
return CallReturnWires(regular_returns=[elem], inout_returns=[array])

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
[array, idx] = args
[array_ty, idx_ty] = self.ty.input
[elem_ty, *_] = self.ty.output
if elem_ty.type_bound() == ht.TypeBound.Any:
return self.build_linear_getitem(array, array_ty, idx, idx_ty, elem_ty)
if self.elem_ty.type_bound() == ht.TypeBound.Any:
return self.build_linear_getitem(array, idx)
else:
return self.build_classical_getitem(array, array_ty, idx, idx_ty, elem_ty)
return self.build_classical_getitem(array, idx)

def compile(self, args: list[Wire]) -> list[Wire]:
raise InternalGuppyError("Call compile_with_inouts instead")


class ArraySetitemCompiler(CustomCallCompiler):
class ArraySetitemCompiler(ArrayCompiler):
"""Compiler for the `array.__setitem__` function."""

def build_classical_setitem(
self,
array: Wire,
array_ty: ht.Type,
idx: Wire,
idx_ty: ht.Type,
elem: Wire,
elem_ty: ht.Type,
length: ht.TypeArg,
self, array: Wire, idx: Wire, elem: Wire
) -> CallReturnWires:
"""Lowers a call to `array.__setitem__` for classical arrays."""
array, _ = build_array_set(
self.builder, array, array_ty, idx, idx_ty, elem, elem_ty, length
idx = self.builder.add_op(convert_itousize(), idx)
result = self.builder.add_op(
array_set(self.elem_ty, self.length), array, idx, elem
)
# Unwrap the result, but we don't have to hold onto the returned old value
_, array = build_unwrap_right(self.builder, result, "Array index out of bounds")
return CallReturnWires(regular_returns=[], inout_returns=[array])

def build_linear_setitem(
self,
array: Wire,
array_ty: ht.Type,
idx: Wire,
idx_ty: ht.Type,
elem: Wire,
elem_ty: ht.Type,
length: ht.TypeArg,
self, array: Wire, idx: Wire, elem: Wire
) -> CallReturnWires:
"""Lowers a call to `array.__setitem__` for linear arrays."""
# Embed the element into an optional
elem_opt_ty = ht.Sum([[elem_ty], []])
elem = self.builder.add_op(ops.Tag(0, elem_opt_ty), elem)
array, old_elem = build_array_set(
self.builder, array, array_ty, idx, idx_ty, elem, elem_opt_ty, length
elem_opt_ty = ht.Option(self.elem_ty)
elem = self.builder.add_op(ops.Tag(1, elem_opt_ty), elem)
idx = self.builder.add_op(convert_itousize(), idx)
result = self.builder.add_op(
array_set(elem_opt_ty, self.length), array, idx, elem
)
old_elem_opt, array = build_unwrap_right(
self.builder, result, "Array index out of bounds"
)
# Check that the old element was `None`
conditional = self.builder.add_conditional(old_elem)
with conditional.add_case(0) as case:
error = build_error(case, 1, "Linear array element has not been used")
build_panic(case, [elem_ty], [], error, *case.inputs())
case.set_outputs()
with conditional.add_case(1) as case:
case.set_outputs()
build_unwrap_left(
self.builder, old_elem_opt, "Linear array element has not been used"
)
return CallReturnWires(regular_returns=[], inout_returns=[array])

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
[array, idx, elem] = args
[array_ty, idx_ty, elem_ty] = self.ty.input
length = self.type_args[1].to_hugr()
if elem_ty.type_bound() == ht.TypeBound.Any:
return self.build_linear_setitem(
array, array_ty, idx, idx_ty, elem, elem_ty, length
)
if self.elem_ty.type_bound() == ht.TypeBound.Any:
return self.build_linear_setitem(array, idx, elem)
else:
return self.build_classical_setitem(
array, array_ty, idx, idx_ty, elem, elem_ty, length
)
return self.build_classical_setitem(array, idx, elem)

def compile(self, args: list[Wire]) -> list[Wire]:
raise InternalGuppyError("Call compile_with_inouts instead")


def build_array_set(
builder: DfBase[ops.DfParentOp],
array: Wire,
array_ty: ht.Type,
idx: Wire,
idx_ty: ht.Type,
elem: Wire,
elem_ty: ht.Type,
length: ht.TypeArg,
) -> tuple[Wire, Wire]:
"""Builds an array set operation, returning the original element."""
sig = ht.FunctionType(
[array_ty, ht.USize(), elem_ty],
[ht.Sum([[elem_ty, array_ty], [elem_ty, array_ty]])],
)
if idx_ty != ht.USize():
idx = builder.add_op(convert_itousize(), idx)
op = ops.ExtOp(
hugr.std.PRELUDE.get_op("set"), sig, [length, ht.TypeTypeArg(elem_ty)]
)
[result] = builder.add_op(op, array, idx, elem)
conditional = builder.add_conditional(result)
with conditional.add_case(0) as case:
error = build_error(case, 1, "array set index out of bounds")
case.set_outputs(
*build_panic(
case, [elem_ty, array_ty], [elem_ty, array_ty], error, *case.inputs()
)
)
with conditional.add_case(1) as case:
case.set_outputs(*case.inputs())
[elem, array] = conditional
return (array, elem)


def build_array_get(
builder: DfBase[ops.DfParentOp],
array: Wire,
array_ty: ht.Type,
idx: Wire,
idx_ty: ht.Type,
elem_ty: ht.Type,
length: ht.TypeArg,
) -> Wire:
"""Builds an array get operation, returning the original element."""
sig = ht.FunctionType([array_ty, ht.USize()], [ht.Sum([[], [elem_ty]])])
op = ops.ExtOp(
hugr.std.PRELUDE.get_op("get"), sig, [length, ht.TypeTypeArg(elem_ty)]
)
if idx_ty != ht.USize():
idx = builder.add_op(convert_itousize(), idx)
[result] = builder.add_op(op, array, idx)
conditional = builder.add_conditional(result)
with conditional.add_case(0) as case:
error = build_error(case, 1, "array get index out of bounds")
case.set_outputs(*build_panic(case, [], [elem_ty], error))
with conditional.add_case(1) as case:
case.set_outputs(*case.inputs())
return conditional


def new_array(length: int, elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that creates a new fixed length array."""
op_def = hugr.std.PRELUDE.get_op("new_array")
sig = ht.FunctionType([elem_ty] * length, [array_type(length, elem_ty)])
return ops.ExtOp(op_def, sig, [ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)])
18 changes: 4 additions & 14 deletions guppylang/prelude/_internal/compiler/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
convert_itousize,
)
from guppylang.prelude._internal.compiler.prelude import (
build_error,
build_panic,
build_unwrap,
build_unwrap_left,
build_unwrap_right,
)
from guppylang.tys.arg import TypeArg
Expand Down Expand Up @@ -191,18 +190,9 @@ def build_linear_setitem(
self.builder, result, "List index out of bounds"
)
# Check that the old element was `None`
conditional = self.builder.add_conditional(old_elem_opt, list_wire)
with conditional.add_case(0) as case:
case.set_outputs(*case.inputs())
with conditional.add_case(1) as case:
# Note: This case can only happen if users manually call `xs.__setitem__`
# since regular indexing `xs[i]` is only allowed in inout position. An error
# in that situation would be a compiler bug!
old_elem, list_wire = case.inputs()
error = build_error(case, 1, "Linear list element has not been used")
build_panic(case, [elem_ty], [], error, old_elem)
case.set_outputs(list_wire)
(list_wire,) = conditional.outputs()
build_unwrap_left(
self.builder, old_elem_opt, "Linear list element has not been used"
)
return CallReturnWires(regular_returns=[], inout_returns=[list_wire])

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
Expand Down
Loading

0 comments on commit 83b9f31

Please sign in to comment.