Skip to content

Commit

Permalink
feat: Add array type (#258)
Browse files Browse the repository at this point in the history
Co-authored-by: Craig Roy <[email protected]>
  • Loading branch information
mark-koch and croyzor authored Jun 25, 2024
1 parent 4dd7ca0 commit 041c621
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 9 deletions.
37 changes: 33 additions & 4 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,26 @@

from guppylang.ast_util import AstNode, get_type, with_loc
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args
from guppylang.checker.expr_checker import (
ExprSynthesizer,
check_call,
check_num_args,
synthesize_call,
)
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.builtin import bool_type, list_type
from guppylang.tys.subst import Subst
from guppylang.tys.arg import ConstArg
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NumericType, Type, unify


Expand Down Expand Up @@ -241,6 +248,28 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return args, subst


class ArrayLenChecker(CustomCallChecker):
"""Function call checker for the `array.__len__` function."""

@staticmethod
def _get_const_len(inst: Inst) -> ast.expr:
"""Helper function to extract the static length from the inferred type args."""
# TODO: This will stop working once we allow generic function defs. Then, the
# argument could also just be variable instead of a concrete number.
match inst:
case [_, ConstArg(const=ConstValue(value=int(n)))]:
return ast.Constant(value=n)
raise InternalGuppyError(f"array.__len__: Invalid instantiation: {inst}")

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
_, _, inst = synthesize_call(self.func.ty, args, self.node, self.ctx)
return self._get_const_len(inst), int_type()

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
_, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx)
return self._get_const_len(inst), subst


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""

Expand Down
25 changes: 20 additions & 5 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# mypy: disable-error-code="empty-body, misc, override, valid-type, no-untyped-def"

from typing import Any
from typing import Any, Generic, TypeVar

from hugr.serialization import tys

Expand All @@ -12,6 +12,7 @@
from guppylang.hugr_builder.hugr import DummyOp
from guppylang.module import GuppyModule
from guppylang.prelude._internal import (
ArrayLenChecker,
CallableChecker,
CoercingChecker,
DunderChecker,
Expand All @@ -29,6 +30,7 @@
logic_op,
)
from guppylang.tys.builtin import (
array_type_def,
bool_type_def,
float_type_def,
int_type_def,
Expand Down Expand Up @@ -56,11 +58,12 @@ class nat:
"""Class to import in order to use nats."""


class array:
"""Class to import in order to use arrays."""
_T = TypeVar("_T")
_n = TypeVar("_n")


def __class_getitem__(cls, item):
return cls
class array(Generic[_T, _n]):
"""Class to import in order to use arrays."""


@guppy.extend_type(builtins, bool_type_def)
Expand Down Expand Up @@ -596,6 +599,18 @@ def reverse(self: linst[L]) -> linst[L]: ...
def sort(self: list[T]) -> None: ...


n = guppy.nat_var(builtins, "n")


@guppy.extend_type(builtins, array_type_def)
class Array:
@guppy.hugr_op(builtins, DummyOp("ArrayGet"))
def __getitem__(self: array[T, n], idx: int) -> T: ...

@guppy.custom(builtins, checker=ArrayLenChecker())
def __len__(self: array[T, n]) -> int: ...


@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False)
def abs(x): ...

Expand Down
Empty file.
7 changes: 7 additions & 0 deletions tests/error/array_errors/linear_index.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def main(qs: array[qubit, 42]) -> int:
14: return qs[0]
^^
GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall n, T: nat. (array[T, n], int) -> T` with linear type `qubit`
17 changes: 17 additions & 0 deletions tests/error/array_errors/linear_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main(qs: array[qubit, 42]) -> int:
return qs[0]


module.compile()
7 changes: 7 additions & 0 deletions tests/error/array_errors/linear_len.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def main(qs: array[qubit, 42]) -> int:
14: return len(qs)
^^^^^^^
GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall n, T: nat. array[T, n] -> int` with linear type `qubit`
17 changes: 17 additions & 0 deletions tests/error/array_errors/linear_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def main(qs: array[qubit, 42]) -> int:
return len(qs)


module.compile()
13 changes: 13 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ def coerce(x: int, y: float) -> float:
validate(coerce)


def test_nat(validate):
@compile_guppy
def foo(
a: nat, b: nat, c: bool, d: int, e: float
) -> tuple[nat, bool, int, float, float]:
b, c, d, e = nat(b), nat(c), nat(d), nat(e)
x = a + b * c // d - e
y = e / b
return x, bool(x), int(x), float(x), y

validate(foo)


def test_arith_big(validate):
@compile_guppy
def arith(x: int, y: float, z: int) -> bool:
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from hugr.serialization import ops

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude._internal import ConstInt
from guppylang.prelude.builtins import array
from tests.util import compile_guppy


def test_len(validate):
module = GuppyModule("test")

@guppy(module)
def main(xs: array[float, 42]) -> int:
return len(xs)

hg = module.compile()
validate(hg)

[val] = [
node.op.root.v.root
for node in hg.nodes()
if isinstance(node.op.root, ops.Const)
]
assert isinstance(val, ops.ExtensionValue)
assert isinstance(val.value.v, ConstInt)
assert val.value.v.value == 42


def test_index(validate):
@compile_guppy
def main(xs: array[int, 5], i: int) -> int:
return xs[0] + xs[i] + xs[xs[2 * i]]

validate(main)

0 comments on commit 041c621

Please sign in to comment.