From d7067356c71cbcc5352e69ea4eed6bdc1d0c1ec8 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 25 Jun 2024 09:08:58 +0100 Subject: [PATCH] feat: Allow constant nats as type args (#255) --- guppylang/checker/core.py | 2 + guppylang/checker/expr_checker.py | 6 +- guppylang/decorator.py | 12 ++- guppylang/definition/parameter.py | 16 +++- guppylang/module.py | 6 +- guppylang/prelude/builtins.py | 7 ++ guppylang/tys/arg.py | 40 ++++++--- guppylang/tys/builtin.py | 29 +++++- guppylang/tys/const.py | 68 ++++++++++++--- guppylang/tys/param.py | 35 ++++++-- guppylang/tys/parsing.py | 15 +++- guppylang/tys/printing.py | 5 ++ guppylang/tys/subst.py | 35 ++++++-- guppylang/tys/ty.py | 92 +++++++++++++++----- tests/error/misc_errors/negative_nat_arg.err | 6 ++ tests/error/misc_errors/negative_nat_arg.py | 7 ++ tests/error/poly_errors/arg_mismatch3.err | 7 ++ tests/error/poly_errors/arg_mismatch3.py | 20 +++++ tests/error/poly_errors/arg_mismatch4.err | 6 ++ tests/error/poly_errors/arg_mismatch4.py | 13 +++ tests/error/poly_errors/arg_mismatch5.err | 6 ++ tests/error/poly_errors/arg_mismatch5.py | 13 +++ tests/integration/test_poly.py | 13 +++ 23 files changed, 386 insertions(+), 73 deletions(-) create mode 100644 tests/error/misc_errors/negative_nat_arg.err create mode 100644 tests/error/misc_errors/negative_nat_arg.py create mode 100644 tests/error/poly_errors/arg_mismatch3.err create mode 100644 tests/error/poly_errors/arg_mismatch3.py create mode 100644 tests/error/poly_errors/arg_mismatch4.err create mode 100644 tests/error/poly_errors/arg_mismatch4.py create mode 100644 tests/error/poly_errors/arg_mismatch5.err create mode 100644 tests/error/poly_errors/arg_mismatch5.py diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 70a74392..e0c2a3cb 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -12,6 +12,7 @@ from guppylang.definition.ty import TypeDef from guppylang.definition.value import CallableDef from guppylang.tys.builtin import ( + array_type_def, bool_type_def, callable_type_def, float_type_def, @@ -76,6 +77,7 @@ def default() -> "Globals": float_type_def, list_type_def, linst_type_def, + array_type_def, ] defs = {defn.id: defn for defn in builtins} names = {defn.name: defn.id for defn in builtins} diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index e0ee56b6..29b37869 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -641,7 +641,7 @@ def check_type_against( "free variables", node, ) - inst = [TypeArg(subst[v]) for v in free_vars] + inst = [subst[v].to_arg() for v in free_vars] subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements @@ -725,7 +725,7 @@ def synthesize_call( # Success implies that the substitution is closed assert all(not t.unsolved_vars for t in subst.values()) - inst = [TypeArg(subst[v]) for v in free_vars] + inst = [subst[v].to_arg() for v in free_vars] # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) @@ -805,7 +805,7 @@ def check_call( # Success implies that the substitution is closed assert all(not t.unsolved_vars for t in subst.values()) - inst = [TypeArg(subst[v]) for v in free_vars] + inst = [subst[v].to_arg() for v in free_vars] subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 2b5f94f5..7585129e 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -21,12 +21,13 @@ from guppylang.definition.declaration import RawFunctionDecl from guppylang.definition.extern import RawExternDef from guppylang.definition.function import RawFunctionDef, parse_py_func -from guppylang.definition.parameter import TypeVarDef +from guppylang.definition.parameter import ConstVarDef, TypeVarDef from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError, MissingModuleError, pretty_errors from guppylang.hugr_builder.hugr import Hugr from guppylang.module import GuppyModule, PyFunc +from guppylang.tys.ty import NumericType FuncDefDecorator = Callable[[PyFunc], RawFunctionDef] FuncDeclDecorator = Callable[[PyFunc], RawFunctionDecl] @@ -172,6 +173,15 @@ def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> Type # that is executed by interpreter before handing it to Guppy. return TypeVar(name) + @pretty_errors + def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef: + """Creates a new const nat variable in a module.""" + defn = ConstVarDef( + DefId.fresh(module), name, None, NumericType(NumericType.Kind.Nat) + ) + module.register_def(defn) + return defn + @pretty_errors def custom( self, diff --git a/guppylang/definition/parameter.py b/guppylang/definition/parameter.py index 270811c1..29597d44 100644 --- a/guppylang/definition/parameter.py +++ b/guppylang/definition/parameter.py @@ -2,7 +2,8 @@ from dataclasses import dataclass, field from guppylang.definition.common import CompiledDef, Definition -from guppylang.tys.param import Parameter, TypeParam +from guppylang.tys.param import ConstParam, Parameter, TypeParam +from guppylang.tys.ty import Type class ParamDef(Definition): @@ -24,3 +25,16 @@ class TypeVarDef(ParamDef, CompiledDef): def to_param(self, idx: int) -> TypeParam: """Creates a parameter from this definition.""" return TypeParam(idx, self.name, self.can_be_linear) + + +@dataclass(frozen=True) +class ConstVarDef(ParamDef, CompiledDef): + """A constant variable definition.""" + + ty: Type + + description: str = field(default="const variable", init=False) + + def to_param(self, idx: int) -> ConstParam: + """Creates a parameter from this definition.""" + return ConstParam(idx, self.name, self.ty) diff --git a/guppylang/module.py b/guppylang/module.py index 513bd658..423299ca 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -19,7 +19,7 @@ from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.declaration import RawFunctionDecl from guppylang.definition.function import RawFunctionDef -from guppylang.definition.parameter import TypeVarDef +from guppylang.definition.parameter import ParamDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError, pretty_errors from guppylang.hugr_builder.hugr import Hugr @@ -86,7 +86,7 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: # For now, we can only import custom functions if any( - not isinstance(v, CustomFunctionDef | TypeDef | TypeVarDef) + not isinstance(v, CustomFunctionDef | TypeDef | ParamDef) for v in m._compiled_globals.values() ): raise GuppyError( @@ -111,7 +111,7 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None: self._instance_func_buffer[defn.name] = defn else: self._check_name_available(defn.name, defn.defined_at) - if isinstance(defn, TypeDef | TypeVarDef): + if isinstance(defn, TypeDef | ParamDef): self._raw_type_defs[defn.id] = defn else: self._raw_defs[defn.id] = defn diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index d5ae7b43..1654466c 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -56,6 +56,13 @@ class nat: """Class to import in order to use nats.""" +class array: + """Class to import in order to use arrays.""" + + def __class_getitem__(cls, item): + return cls + + @guppy.extend_type(builtins, bool_type_def) class Bool: @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) diff --git a/guppylang/tys/arg.py b/guppylang/tys/arg.py index 1b782e6f..9189196c 100644 --- a/guppylang/tys/arg.py +++ b/guppylang/tys/arg.py @@ -4,8 +4,9 @@ from hugr.serialization import tys +from guppylang.error import InternalGuppyError from guppylang.tys.common import ToHugr, Transformable, Transformer, Visitor -from guppylang.tys.const import Const +from guppylang.tys.const import BoundConstVar, Const, ConstValue, ExistentialConstVar from guppylang.tys.var import ExistentialVar if TYPE_CHECKING: @@ -62,27 +63,42 @@ def transform(self, transformer: Transformer) -> Argument: @dataclass(frozen=True) class ConstArg(ArgumentBase): - """Argument that can be substituted for a `ConstParameter`. + """Argument that can be substituted for a `ConstParam`.""" - Note that support for this kind is not implemented yet. - """ - - # Hugr value to instantiate const: Const @property def unsolved_vars(self) -> set[ExistentialVar]: - """The existential type variables contained in this argument.""" - raise NotImplementedError + """The existential const variables contained in this argument.""" + return self.const.unsolved_vars def to_hugr(self) -> tys.TypeArg: - """Computes the Hugr representation of the argument.""" - raise NotImplementedError + """Computes the Hugr representation of this argument.""" + from guppylang.tys.ty import NumericType + + match self.const: + case ConstValue(value=v, ty=NumericType(kind=NumericType.Kind.Nat)): + assert isinstance(v, int) + return tys.TypeArg(tys.BoundedNatArg(n=v)) + case BoundConstVar(idx=idx): + hugr_ty = self.const.ty.to_hugr() + assert isinstance(hugr_ty.root, tys.Opaque) + param = tys.TypeParam(tys.BoundedNatParam(bound=None)) + return tys.TypeArg(tys.VariableArg(idx=idx, cached_decl=param)) + case ConstValue() | BoundConstVar(): + # TODO: Handle other cases besides nats + raise NotImplementedError + case ExistentialConstVar(): + raise InternalGuppyError( + "Tried to convert unsolved constant variable to Hugr" + ) def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this argument.""" - raise NotImplementedError + visitor.visit(self) def transform(self, transformer: Transformer) -> Argument: """Accepts a transformer on this argument.""" - raise NotImplementedError + return transformer.transform(self) or ConstArg( + self.const.transform(transformer) + ) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index aeaa021a..409e36fc 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -8,8 +8,8 @@ from guppylang.definition.common import DefId from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError -from guppylang.tys.arg import Argument, TypeArg -from guppylang.tys.param import TypeParam +from guppylang.tys.arg import Argument, ConstArg, TypeArg +from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( FunctionType, NoneType, @@ -136,6 +136,20 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: return tys.Type(ty) +def _array_to_hugr(args: Sequence[Argument]) -> tys.Type: + # Type checker ensures that we get a two args + [ty_arg, len_arg] = args + assert isinstance(ty_arg, TypeArg) + assert isinstance(len_arg, ConstArg) + ty = tys.Opaque( + extension="prelude", + id="array", + args=[len_arg.to_hugr(), ty_arg.to_hugr()], + bound=ty_arg.ty.hugr_bound, + ) + return tys.Type(ty) + + callable_type_def = _CallableTypeDef(DefId.fresh(), None) tuple_type_def = _TupleTypeDef(DefId.fresh(), None) none_type_def = _NoneTypeDef(DefId.fresh(), None) @@ -172,6 +186,17 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: always_linear=False, to_hugr=_list_to_hugr, ) +array_type_def = OpaqueTypeDef( + id=DefId.fresh(), + name="array", + defined_at=None, + params=[ + TypeParam(0, "T", can_be_linear=True), + ConstParam(1, "n", NumericType(NumericType.Kind.Nat)), + ], + always_linear=False, + to_hugr=_array_to_hugr, +) def bool_type() -> OpaqueType: diff --git a/guppylang/tys/const.py b/guppylang/tys/const.py index 358490cc..d942f417 100644 --- a/guppylang/tys/const.py +++ b/guppylang/tys/const.py @@ -1,18 +1,18 @@ -from abc import ABC +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING - -from hugr.serialization import ops +from typing import TYPE_CHECKING, Any, TypeAlias from guppylang.error import InternalGuppyError +from guppylang.tys.common import Transformable, Transformer, Visitor from guppylang.tys.var import BoundVar, ExistentialVar if TYPE_CHECKING: + from guppylang.tys.arg import ConstArg from guppylang.tys.ty import Type @dataclass(frozen=True) -class Const(ABC): +class ConstBase(Transformable["Const"], ABC): """Abstract base class for constants arguments in the type system. In principle, we can allow constants of any type representable in the type system. @@ -26,22 +26,52 @@ def __post_init__(self) -> None: if self.ty.unsolved_vars: raise InternalGuppyError("Attempted to create constant with unsolved type") + @abstractmethod + def cast(self) -> "Const": + """Casts an implementor of `ConstBase` into a `Const`. + + This enforces that all implementors of `ConstBase` can be embedded into the + `Const` union type. + """ + + @property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this constant.""" + return set() + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this constant.""" + visitor.visit(self) + + def transform(self, transformer: Transformer, /) -> "Const": + """Accepts a transformer on this constant.""" + return transformer.transform(self) or self.cast() + + def to_arg(self) -> "ConstArg": + """Wraps this constant into a type argument.""" + from guppylang.tys.arg import ConstArg + + return ConstArg(self.cast()) + @dataclass(frozen=True) -class ConstValue(Const): +class ConstValue(ConstBase): """A constant value in the type system. For example, in the type `array[int, 5]` the second argument is a `ConstArg` that contains a `ConstValue(5)`. """ - # Hugr encoding of the value - # TODO: We might need a Guppy representation of this... - value: ops.Value + # TODO: We will need a proper Guppy representation of this in the future + value: Any + + def cast(self) -> "Const": + """Casts an implementor of `ConstBase` into a `Const`.""" + return self @dataclass(frozen=True) -class BoundConstVar(BoundVar, Const): +class BoundConstVar(BoundVar, ConstBase): """Bound variable referencing a `ConstParam`. For example, in the function type `forall n: int. array[float, n] -> array[int, n]`, @@ -49,9 +79,13 @@ class BoundConstVar(BoundVar, Const): `BoundConstVar(idx=0)`. """ + def cast(self) -> "Const": + """Casts an implementor of `ConstBase` into a `Const`.""" + return self + @dataclass(frozen=True) -class ExistentialConstVar(ExistentialVar, Const): +class ExistentialConstVar(ExistentialVar, ConstBase): """Existential constant variable. During type checking we try to solve all existential constant variables and @@ -61,3 +95,15 @@ class ExistentialConstVar(ExistentialVar, Const): @classmethod def fresh(cls, display_name: str, ty: "Type") -> "ExistentialConstVar": return ExistentialConstVar(ty, display_name, next(cls._fresh_id)) + + @property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this constant.""" + return {self} + + def cast(self) -> "Const": + """Casts an implementor of `ConstBase` into a `Const`.""" + return self + + +Const: TypeAlias = ConstValue | BoundConstVar | ExistentialConstVar diff --git a/guppylang/tys/param.py b/guppylang/tys/param.py index 0e5a2d11..af521074 100644 --- a/guppylang/tys/param.py +++ b/guppylang/tys/param.py @@ -11,6 +11,7 @@ from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg from guppylang.tys.common import ToHugr +from guppylang.tys.const import BoundConstVar, ExistentialConstVar from guppylang.tys.var import ExistentialVar if TYPE_CHECKING: @@ -131,10 +132,7 @@ def to_hugr(self) -> tys.TypeParam: @dataclass(frozen=True) class ConstParam(ParameterBase): - """A parameter of kind constant. Used to define fixed-size arrays etc. - - Note that support for this kind is not implemented yet. - """ + """A parameter of kind constant. Used to define fixed-size arrays etc.""" ty: "Type" @@ -153,7 +151,17 @@ def check_arg(self, arg: Argument, loc: AstNode | None = None) -> ConstArg: Raises a user error if the argument is not valid. """ - raise NotImplementedError + match arg: + case ConstArg(const): + if const.ty != self.ty: + raise GuppyTypeError( + f"Expected argument of type `{self.ty}`, got {const.ty}", loc + ) + return arg + case TypeArg(): + raise GuppyTypeError( + f"Expected argument of type `{self.ty}`, got type", loc + ) def to_existential(self) -> tuple[Argument, ExistentialVar]: """Creates a fresh existential variable that can be instantiated for this @@ -161,17 +169,28 @@ def to_existential(self) -> tuple[Argument, ExistentialVar]: Returns both the argument and the created variable. """ - raise NotImplementedError + var = ExistentialConstVar.fresh(self.name, self.ty) + return ConstArg(var), var def to_bound(self, idx: int | None = None) -> Argument: """Creates a bound variable with a given index that can be instantiated for this parameter. """ - raise NotImplementedError + if idx is None: + idx = self.idx + return ConstArg(BoundConstVar(self.ty, self.name, idx)) def to_hugr(self) -> tys.TypeParam: """Computes the Hugr representation of the parameter.""" - raise NotImplementedError + from guppylang.tys.ty import NumericType + + match self.ty: + case NumericType(kind=NumericType.Kind.Nat): + return tys.TypeParam(tys.BoundedNatParam(bound=None)) + case _: + hugr_ty = self.ty.to_hugr() + assert isinstance(hugr_ty.root, tys.Opaque) + return tys.TypeParam(tys.OpaqueParam(ty=hugr_ty.root)) def check_all_args( diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index 4154505c..c73222b3 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -10,9 +10,10 @@ from guppylang.definition.parameter import ParamDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError -from guppylang.tys.arg import Argument, TypeArg +from guppylang.tys.arg import Argument, ConstArg, TypeArg +from guppylang.tys.const import ConstValue from guppylang.tys.param import Parameter, TypeParam -from guppylang.tys.ty import NoneType, TupleType, Type +from guppylang.tys.ty import NoneType, NumericType, TupleType, Type def arg_from_ast( @@ -89,6 +90,16 @@ def arg_from_ast( if isinstance(node, ast.Constant) and node.value is None: return TypeArg(NoneType()) + # Integer literals are turned into nat args since these are the only ones we support + # right now. + # TODO: Once we also have int args etc, we need proper inference logic here + if isinstance(node, ast.Constant) and isinstance(node.value, int): + # Fun fact: int ast.Constant values are never negative since e.g. `-5` is a + # `ast.UnaryOp` negation of a `ast.Constant(5)` + assert node.value >= 0 + nat_ty = NumericType(NumericType.Kind.Nat) + return ConstArg(ConstValue(nat_ty, node.value)) + # Finally, we also support delayed annotations in strings if isinstance(node, ast.Constant) and isinstance(node.value, str): try: diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 2d0efb44..7270219a 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -2,6 +2,7 @@ from guppylang.error import InternalGuppyError from guppylang.tys.arg import ConstArg, TypeArg +from guppylang.tys.const import ConstValue from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( FunctionType, @@ -130,6 +131,10 @@ def _visit_TypeArg(self, arg: TypeArg, inside_row: bool) -> str: def _visit_ConstArg(self, arg: ConstArg, inside_row: bool) -> str: return self._visit(arg.const, inside_row) + @_visit.register + def _visit_ConstValue(self, c: ConstValue, inside_row: bool) -> str: + return str(c.value) + def _wrap(s: str, inside_row: bool) -> str: return f"({s})" if inside_row else s diff --git a/guppylang/tys/subst.py b/guppylang/tys/subst.py index 118b3f50..67ae565e 100644 --- a/guppylang/tys/subst.py +++ b/guppylang/tys/subst.py @@ -3,13 +3,19 @@ from typing import Any from guppylang.error import InternalGuppyError -from guppylang.tys.arg import Argument, TypeArg +from guppylang.tys.arg import Argument, ConstArg, TypeArg from guppylang.tys.common import Transformer -from guppylang.tys.const import BoundConstVar, ExistentialConstVar -from guppylang.tys.ty import BoundTypeVar, ExistentialTypeVar, FunctionType, Type +from guppylang.tys.const import BoundConstVar, Const, ConstBase, ExistentialConstVar +from guppylang.tys.ty import ( + BoundTypeVar, + ExistentialTypeVar, + FunctionType, + Type, + TypeBase, +) from guppylang.tys.var import ExistentialVar -Subst = dict[ExistentialVar, Type] # TODO: `GuppyType | Const` or `Argument` ?? +Subst = dict[ExistentialVar, Type | Const] Inst = Sequence[Argument] @@ -25,11 +31,15 @@ def transform(self, ty: Any) -> Any | None: # type: ignore[override] @transform.register def _transform_ExistentialTypeVar(self, ty: ExistentialTypeVar) -> Type | None: - return self.subst.get(ty, None) + s = self.subst.get(ty, None) + assert not isinstance(s, ConstBase) + return s @transform.register - def _transform_ExistentialConstVar(self, ty: ExistentialConstVar) -> Type | None: - raise NotImplementedError + def _transform_ExistentialConstVar(self, c: ExistentialConstVar) -> Const | None: + s = self.subst.get(c, None) + assert not isinstance(s, TypeBase) + return s class Instantiator(Transformer): @@ -54,8 +64,15 @@ def _transform_BoundTypeVar(self, ty: BoundTypeVar) -> Type | None: return BoundTypeVar(ty.display_name, ty.idx - len(self.inst), ty.linear) @transform.register - def _transform_BoundConstVar(self, ty: BoundConstVar) -> Type | None: - raise NotImplementedError + def _transform_BoundConstVar(self, c: BoundConstVar) -> Const | None: + # Instantiate if const value for the index is available + if c.idx < len(self.inst): + arg = self.inst[c.idx] + assert isinstance(arg, ConstArg) + return arg.const + + # Otherwise, lower the de Bruijn index + return BoundConstVar(c.ty, c.display_name, c.idx - len(self.inst)) @transform.register def _transform_FunctionType(self, ty: FunctionType) -> Type | None: diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 8f906846..50a1dcb0 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -11,7 +11,7 @@ from guppylang.error import InternalGuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg from guppylang.tys.common import ToHugr, Transformable, Transformer, Visitor -from guppylang.tys.const import ExistentialConstVar +from guppylang.tys.const import Const, ConstValue, ExistentialConstVar from guppylang.tys.param import Parameter from guppylang.tys.var import BoundVar, ExistentialVar @@ -43,6 +43,14 @@ def hugr_bound(self) -> tys.TypeBound: bound exactly right during serialisation, the Hugr validator will complain. """ + @abstractmethod + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`. + + This enforces that all implementors of `TypeBase` can be embedded into the + `Type` union type. + """ + @cached_property def unsolved_vars(self) -> set[ExistentialVar]: """The existential type variables contained in this type.""" @@ -54,6 +62,10 @@ def substitute(self, subst: "Subst") -> "Type": return self.transform(Substituter(subst)) + def to_arg(self) -> TypeArg: + """Wraps this constant into a type argument.""" + return TypeArg(self.cast()) + def __str__(self) -> str: """Returns a human-readable representation of the type.""" from guppylang.tys.printing import TypePrinter @@ -103,14 +115,7 @@ def linear(self) -> bool: @cached_property def unsolved_vars(self) -> set[ExistentialVar]: """The existential type variables contained in this type.""" - unsolved = set() - for arg in self.args: - match arg: - case TypeArg(ty): - unsolved |= ty.unsolved_vars - case ConstArg(c) if isinstance(c, ExistentialConstVar): - unsolved.add(c) - return unsolved + return set().union(*(arg.unsolved_vars for arg in self.args)) @cached_property def hugr_bound(self) -> tys.TypeBound: @@ -149,6 +154,10 @@ def hugr_bound(self) -> tys.TypeBound: # This is fine since Guppy doesn't use the equatable feature anyways. return TypeBound.Copyable + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" return tys.Type(tys.Variable(i=self.idx, b=self.hugr_bound)) @@ -195,6 +204,10 @@ def hugr_bound(self) -> tys.TypeBound: "Tried to compute bound of unsolved existential type variable" ) + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" raise InternalGuppyError( @@ -222,6 +235,10 @@ class NoneType(TypeBase): # empty rows when generating a Hugr preserve: bool = field(default=False, compare=False) + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" return TupleType([]).to_hugr() @@ -255,6 +272,10 @@ def linear(self) -> bool: """Whether this type should be treated linearly.""" return False + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" match self.kind: @@ -330,6 +351,10 @@ def parametrized(self) -> bool: """Whether the function is parametrized.""" return len(self.params) > 0 + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" if self.parametrized: @@ -423,6 +448,10 @@ def __init__(self, element_types: Sequence["Type"], preserve: bool = False) -> N def intrinsically_linear(self) -> bool: return False + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" # Tuples are encoded as a unary sum. Note that we need to make a copy of this @@ -459,6 +488,10 @@ def __init__(self, element_types: Sequence["Type"]) -> None: def intrinsically_linear(self) -> bool: return False + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" rows = [type_to_row(ty) for ty in self.element_types] @@ -498,6 +531,10 @@ def hugr_bound(self) -> tys.TypeBound: return self.defn.bound return super().hugr_bound + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" return self.defn.to_hugr(self.args) @@ -529,6 +566,10 @@ def intrinsically_linear(self) -> bool: """Whether this type is linear, independent of the arguments.""" return any(f.ty.linear for f in self.defn.fields) + def cast(self) -> "Type": + """Casts an implementor of `TypeBase` into a `Type`.""" + return self + def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" return TupleType([f.ty for f in self.fields]).to_hugr() @@ -591,22 +632,26 @@ def rows_to_hugr(rows: Sequence[TypeRow]) -> list[tys.TypeRow]: return [row_to_hugr(row) for row in rows] -def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None": - """Computes a most general unifier for two types. +def unify(s: Type | Const, t: Type | Const, subst: "Subst | None") -> "Subst | None": + """Computes a most general unifier for two types or constants. Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this not possible. """ + # Make sure that s and t are either both constants or both types + assert isinstance(s, TypeBase) == isinstance(t, TypeBase) if subst is None: return None match s, t: - case ExistentialTypeVar(id=s_id), ExistentialTypeVar(id=t_id) if s_id == t_id: + case ExistentialVar(id=s_id), ExistentialVar(id=t_id) if s_id == t_id: return subst - case ExistentialTypeVar() as s, t: - return _unify_var(s, t, subst) - case s, ExistentialTypeVar() as t: - return _unify_var(t, s, subst) - case BoundTypeVar(idx=s_idx), BoundTypeVar(idx=t_idx) if s_idx == t_idx: + case ExistentialTypeVar() | ExistentialConstVar() as s_var, t: + return _unify_var(s_var, t, subst) + case s, ExistentialTypeVar() | ExistentialConstVar() as t_var: + return _unify_var(t_var, s, subst) + case BoundVar(idx=s_idx), BoundVar(idx=t_idx) if s_idx == t_idx: + return subst + case ConstValue(value=c_value), ConstValue(value=d_value) if c_value == d_value: return subst case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind: return subst @@ -626,8 +671,10 @@ def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None": return None -def _unify_var(var: ExistentialTypeVar, t: Type, subst: "Subst") -> "Subst | None": - """Helper function for unification of type variables.""" +def _unify_var( + var: ExistentialTypeVar | ExistentialConstVar, t: Type | Const, subst: "Subst" +) -> "Subst | None": + """Helper function for unification of type or const variables.""" if var in subst: return unify(subst[var], t, subst) if isinstance(t, ExistentialTypeVar) and t in subst: @@ -650,8 +697,11 @@ def _unify_args( if res is None: return None subst = res - case ConstArg(), ConstArg(): - raise NotImplementedError + case ConstArg(const=sa_const), ConstArg(const=ta_const): + res = unify(sa_const, ta_const, subst) + if res is None: + return None + subst = res case _: return None return subst diff --git a/tests/error/misc_errors/negative_nat_arg.err b/tests/error/misc_errors/negative_nat_arg.err new file mode 100644 index 00000000..94e3b222 --- /dev/null +++ b/tests/error/misc_errors/negative_nat_arg.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @compile_guppy +5: def foo(x: array[int, -5]) -> None: + ^^ +GuppyError: Not a valid type argument diff --git a/tests/error/misc_errors/negative_nat_arg.py b/tests/error/misc_errors/negative_nat_arg.py new file mode 100644 index 00000000..4b20d072 --- /dev/null +++ b/tests/error/misc_errors/negative_nat_arg.py @@ -0,0 +1,7 @@ +from guppylang.prelude.builtins import array +from tests.util import compile_guppy + + +@compile_guppy +def foo(x: array[int, -5]) -> None: + pass diff --git a/tests/error/poly_errors/arg_mismatch3.err b/tests/error/poly_errors/arg_mismatch3.err new file mode 100644 index 00000000..4afde1b9 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch3.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: array[int, 42], y: array[int, 43]) -> None: +17: foo(x, y) + ^ +GuppyTypeError: Expected argument of type `array[int, 42]`, got `array[int, 43]` diff --git a/tests/error/poly_errors/arg_mismatch3.py b/tests/error/poly_errors/arg_mismatch3.py new file mode 100644 index 00000000..e7fc3d97 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch3.py @@ -0,0 +1,20 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + +module = GuppyModule("test") + +n = guppy.nat_var(module, "n") + + +@guppy.declare(module) +def foo(x: array[int, n], y: array[int, n]) -> None: + ... + + +@guppy(module) +def main(x: array[int, 42], y: array[int, 43]) -> None: + foo(x, y) + + +module.compile() diff --git a/tests/error/poly_errors/arg_mismatch4.err b/tests/error/poly_errors/arg_mismatch4.err new file mode 100644 index 00000000..75ec9152 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch4.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy(module) +8: def main(x: array[int, bool]) -> None: + ^^^^^^^^^^^^^^^^ +GuppyTypeError: Expected argument of type `nat`, got type diff --git a/tests/error/poly_errors/arg_mismatch4.py b/tests/error/poly_errors/arg_mismatch4.py new file mode 100644 index 00000000..9ffdf91c --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch4.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + +module = GuppyModule("test") + + +@guppy(module) +def main(x: array[int, bool]) -> None: + pass + + +module.compile() diff --git a/tests/error/poly_errors/arg_mismatch5.err b/tests/error/poly_errors/arg_mismatch5.err new file mode 100644 index 00000000..6928c6d5 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch5.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy(module) +8: def main(x: list[42]) -> None: + ^^^^^^^^ +GuppyTypeError: Expected a type, got value of type nat diff --git a/tests/error/poly_errors/arg_mismatch5.py b/tests/error/poly_errors/arg_mismatch5.py new file mode 100644 index 00000000..1b35ddee --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch5.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy(module) +def main(x: list[42]) -> None: + pass + + +module.compile() diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 98107a32..471f4851 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -4,6 +4,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array from guppylang.prelude.quantum import qubit import guppylang.prelude.quantum as quantum @@ -114,6 +115,18 @@ def main(x: int, y: float) -> float: validate(module.compile()) +def test_nat_args(validate): + module = GuppyModule("test") + n = guppy.nat_var(module, "n") + + @guppy.declare(module) + def foo(x: array[int, n]) -> array[int, n]: ... + + @guppy(module) + def main(x: array[int, 42]) -> array[int, 42]: + return foo(x) + + def test_infer_basic(validate): module = GuppyModule("test") T = guppy.type_var(module, "T")