diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 6fdea55b..0f32ff6e 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -13,7 +13,6 @@ CallableChecker, CoercingChecker, DunderChecker, - FailingChecker, NewArrayChecker, ResultChecker, ReversingChecker, @@ -505,77 +504,35 @@ def __trunc__(self: float) -> float: ... @guppy.extend_type(list_type_def) class List: - @guppy.hugr_op(unsupported_op("Append")) - def __add__(self: list[T], other: list[T]) -> list[T]: ... - - @guppy.hugr_op(unsupported_op("IsEmpty")) - def __bool__(self: list[T]) -> bool: ... - - @guppy.hugr_op(unsupported_op("Contains")) - def __contains__(self: list[T], el: T) -> bool: ... - - @guppy.hugr_op(unsupported_op("AssertEmpty")) - def __end__(self: list[T]) -> None: ... - - @guppy.hugr_op(unsupported_op("Lookup")) - def __getitem__(self: list[T], idx: int) -> T: ... - - @guppy.hugr_op(unsupported_op("IsNotEmpty")) - def __hasnext__(self: list[T]) -> tuple[bool, list[T]]: ... - - @guppy.custom(NoopCompiler()) - def __iter__(self: list[T]) -> list[T]: ... + @guppy.hugr_op(unsupported_op("pop")) # TODO: unwrap and swap None + def __getitem__(self: list[L], idx: int) -> L: ... - @guppy.hugr_op(unsupported_op("Length")) - def __len__(self: list[T]) -> int: ... - - @guppy.hugr_op(unsupported_op("Repeat")) - def __mul__(self: list[T], other: int) -> list[T]: ... + @guppy.hugr_op(unsupported_op("set")) # TODO: check None and unwrap + def __setitem__(self: list[L], idx: int, value: L @ owned) -> None: ... - @guppy.hugr_op(unsupported_op("Pop")) - def __next__(self: list[T]) -> tuple[T, list[T]]: ... + @guppy.hugr_op(unsupported_op("length")) # TODO: inout return in wrong order + def __len__(self: list[L]) -> int: ... @guppy.custom(checker=UnsupportedChecker(), higher_order_value=False) def __new__(x): ... - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def __setitem__(self: list[T], idx: int, value: T) -> None: ... - - @guppy.hugr_op(unsupported_op("Append"), ReversingChecker()) - def __radd__(self: list[T], other: list[T]) -> list[T]: ... - - @guppy.hugr_op(unsupported_op("Repeat")) - def __rmul__(self: list[T], other: int) -> list[T]: ... - - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def append(self: list[T], elt: T) -> None: ... - - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def clear(self: list[T]) -> None: ... - - @guppy.custom(NoopCompiler()) # Can be noop since lists are immutable - def copy(self: list[T]) -> list[T]: ... - - @guppy.hugr_op(unsupported_op("Count")) - def count(self: list[T], elt: T) -> int: ... - - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def extend(self: list[T], seq: None) -> None: ... + @guppy.custom(NoopCompiler()) # TODO: define via Guppy source instead + def __iter__(self: list[L] @ owned) -> list[L]: ... - @guppy.hugr_op(unsupported_op("Find")) - def index(self: list[T], elt: T) -> int: ... + @guppy.hugr_op(unsupported_op("IsNotEmpty")) # TODO + def __hasnext__(self: list[L] @ owned) -> tuple[bool, list[L]]: ... - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def pop(self: list[T], idx: int) -> None: ... + @guppy.hugr_op(unsupported_op("AssertEmpty")) # TODO + def __end__(self: list[L] @ owned) -> None: ... - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def remove(self: list[T], elt: T) -> None: ... + @guppy.hugr_op(unsupported_op("pop")) + def __next__(self: list[L] @ owned) -> tuple[L, list[L]]: ... - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def reverse(self: list[T]) -> None: ... + @guppy.hugr_op(unsupported_op("push")) + def append(self: list[L], item: L @ owned) -> None: ... - @guppy.custom(checker=FailingChecker("Guppy lists are immutable")) - def sort(self: list[T]) -> None: ... + @guppy.hugr_op(unsupported_op("pop")) # TODO + def pop(self: list[L]) -> L: ... linst = list diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index fe40dcfb..87a5d74e 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -103,20 +103,14 @@ def check_instantiate( class _ListTypeDef(OpaqueTypeDef): """Type definition associated with the builtin `list` type. - We have a custom definition to give a nicer error message if the user tries to put - linear data into a regular list. + We have a custom definition to disable usage of lists unless experimental features + are enabled. """ def check_instantiate( self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: check_lists_enabled(loc) - if len(args) == 1: - [arg] = args - if isinstance(arg, TypeArg) and arg.ty.linear: - raise GuppyError( - "Type `list` cannot store linear data, use `linst` instead", loc - ) return super().check_instantiate(args, globals, loc) @@ -192,7 +186,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: id=DefId.fresh(), name="list", defined_at=None, - params=[TypeParam(0, "T", can_be_linear=False)], + params=[TypeParam(0, "T", can_be_linear=True)], always_linear=False, to_hugr=_list_to_hugr, ) diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index ee876a9d..976604c7 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -667,10 +667,11 @@ def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | N case NoneType(), NoneType(): return subst case FunctionType() as s, FunctionType() as t if s.params == t.params: - if len(s.inputs) != len(t.inputs) or any( - a.flags != b.flags for a, b in zip(s.inputs, t.inputs, strict=True) - ): + if len(s.inputs) != len(t.inputs): return None + for a, b in zip(s.inputs, t.inputs, strict=True): + if a.ty.linear and b.ty.linear and a.flags != b.flags: + return None return _unify_args(s, t, subst) case TupleType() as s, TupleType() as t: return _unify_args(s, t, subst) diff --git a/tests/error/misc_errors/list_linear.err b/tests/error/misc_errors/list_linear.err deleted file mode 100644 index f6d3e9b2..00000000 --- a/tests/error/misc_errors/list_linear.err +++ /dev/null @@ -1,6 +0,0 @@ -Guppy compilation failed. Error in file $FILE:13 - -11: @guppy(module) -12: def foo() -> list[qubit]: - ^^^^^^^^^^^ -GuppyError: Type `list` cannot store linear data, use `linst` instead diff --git a/tests/error/misc_errors/list_linear.py b/tests/error/misc_errors/list_linear.py deleted file mode 100644 index b734c7df..00000000 --- a/tests/error/misc_errors/list_linear.py +++ /dev/null @@ -1,17 +0,0 @@ -from guppylang.decorator import guppy -from guppylang.module import GuppyModule -from guppylang.prelude.quantum import qubit - -import guppylang.prelude.quantum as quantum - - -module = GuppyModule("test") -module.load_all(quantum) - - -@guppy(module) -def foo() -> list[qubit]: - return [] - - -module.compile() diff --git a/tests/error/poly_errors/non_linear2.err b/tests/error/poly_errors/non_linear2.err index 98408321..5c6827f2 100644 --- a/tests/error/poly_errors/non_linear2.err +++ b/tests/error/poly_errors/non_linear2.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:23 21: @guppy(module) 22: def main() -> None: 23: foo(h) - ^ -GuppyTypeError: Expected argument of type `?T -> ?T`, got `qubit @owned -> qubit` + ^^^^^^ +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. (T -> T) -> None` with linear type `qubit` diff --git a/tests/error/poly_errors/non_linear3.err b/tests/error/poly_errors/non_linear3.err index 9d433921..d70752e2 100644 --- a/tests/error/poly_errors/non_linear3.err +++ b/tests/error/poly_errors/non_linear3.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: @guppy(module) 24: def main() -> None: 25: foo(h) - ^ -GuppyTypeError: Expected argument of type `?T -> None`, got `qubit -> None` + ^^^^^^ +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. (T -> None) -> None` with linear type `qubit` diff --git a/tests/integration/test_list.py b/tests/integration/test_list.py index 30a1a6d5..9bb57130 100644 --- a/tests/integration/test_list.py +++ b/tests/integration/test_list.py @@ -1,4 +1,6 @@ import pytest +from guppylang import qubit, guppy, GuppyModule +from guppylang.prelude.builtins import owned from tests.util import compile_guppy @@ -29,6 +31,17 @@ def test(x: float) -> list[float]: validate(test) +def test_push_pop(validate): + @compile_guppy + def test(xs: list[int]) -> bool: + xs.append(3) + x = xs.pop() + return x == 3 + + validate(test) + + +@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/528") def test_arith(validate): @compile_guppy def test(xs: list[int]) -> list[int]: @@ -39,10 +52,21 @@ def test(xs: list[int]) -> list[int]: validate(test) -@pytest.mark.skip("Requires updating lists to use inout") def test_subscript(validate): @compile_guppy def test(xs: list[float], i: int) -> float: return xs[2 * i] validate(test) + + +def test_linear(validate): + module = GuppyModule("test") + module.load(qubit) + + @guppy(module) + def test(xs: list[qubit], q: qubit @owned) -> int: + xs.append(q) + return len(xs) + + validate(module.compile())