Skip to content

Commit

Permalink
feat: Allow constant nats as type args (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored Jun 25, 2024
1 parent a461a9d commit d706735
Show file tree
Hide file tree
Showing 23 changed files with 386 additions and 73 deletions.
2 changes: 2 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef
from guppylang.tys.builtin import (
array_type_def,
bool_type_def,
callable_type_def,
float_type_def,
Expand Down Expand Up @@ -76,6 +77,7 @@ def default() -> "Globals":
float_type_def,
list_type_def,
linst_type_def,
array_type_def,
]
defs = {defn.id: defn for defn in builtins}
names = {defn.name: defn.id for defn in builtins}
Expand Down
6 changes: 3 additions & 3 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def check_type_against(
"free variables",
node,
)
inst = [TypeArg(subst[v]) for v in free_vars]
inst = [subst[v].to_arg() for v in free_vars]
subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars}

# Finally, check that the instantiation respects the linearity requirements
Expand Down Expand Up @@ -725,7 +725,7 @@ def synthesize_call(

# Success implies that the substitution is closed
assert all(not t.unsolved_vars for t in subst.values())
inst = [TypeArg(subst[v]) for v in free_vars]
inst = [subst[v].to_arg() for v in free_vars]

# Finally, check that the instantiation respects the linearity requirements
check_inst(func_ty, inst, node)
Expand Down Expand Up @@ -805,7 +805,7 @@ def check_call(

# Success implies that the substitution is closed
assert all(not t.unsolved_vars for t in subst.values())
inst = [TypeArg(subst[v]) for v in free_vars]
inst = [subst[v].to_arg() for v in free_vars]
subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}

# Finally, check that the instantiation respects the linearity requirements
Expand Down
12 changes: 11 additions & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.extern import RawExternDef
from guppylang.definition.function import RawFunctionDef, parse_py_func
from guppylang.definition.parameter import TypeVarDef
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.hugr_builder.hugr import Hugr
from guppylang.module import GuppyModule, PyFunc
from guppylang.tys.ty import NumericType

FuncDefDecorator = Callable[[PyFunc], RawFunctionDef]
FuncDeclDecorator = Callable[[PyFunc], RawFunctionDecl]
Expand Down Expand Up @@ -172,6 +173,15 @@ def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> Type
# that is executed by interpreter before handing it to Guppy.
return TypeVar(name)

@pretty_errors
def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef:
"""Creates a new const nat variable in a module."""
defn = ConstVarDef(
DefId.fresh(module), name, None, NumericType(NumericType.Kind.Nat)
)
module.register_def(defn)
return defn

@pretty_errors
def custom(
self,
Expand Down
16 changes: 15 additions & 1 deletion guppylang/definition/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass, field

from guppylang.definition.common import CompiledDef, Definition
from guppylang.tys.param import Parameter, TypeParam
from guppylang.tys.param import ConstParam, Parameter, TypeParam
from guppylang.tys.ty import Type


class ParamDef(Definition):
Expand All @@ -24,3 +25,16 @@ class TypeVarDef(ParamDef, CompiledDef):
def to_param(self, idx: int) -> TypeParam:
"""Creates a parameter from this definition."""
return TypeParam(idx, self.name, self.can_be_linear)


@dataclass(frozen=True)
class ConstVarDef(ParamDef, CompiledDef):
"""A constant variable definition."""

ty: Type

description: str = field(default="const variable", init=False)

def to_param(self, idx: int) -> ConstParam:
"""Creates a parameter from this definition."""
return ConstParam(idx, self.name, self.ty)
6 changes: 3 additions & 3 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.function import RawFunctionDef
from guppylang.definition.parameter import TypeVarDef
from guppylang.definition.parameter import ParamDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.hugr_builder.hugr import Hugr
Expand Down Expand Up @@ -86,7 +86,7 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None:

# For now, we can only import custom functions
if any(
not isinstance(v, CustomFunctionDef | TypeDef | TypeVarDef)
not isinstance(v, CustomFunctionDef | TypeDef | ParamDef)
for v in m._compiled_globals.values()
):
raise GuppyError(
Expand All @@ -111,7 +111,7 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
self._instance_func_buffer[defn.name] = defn
else:
self._check_name_available(defn.name, defn.defined_at)
if isinstance(defn, TypeDef | TypeVarDef):
if isinstance(defn, TypeDef | ParamDef):
self._raw_type_defs[defn.id] = defn
else:
self._raw_defs[defn.id] = defn
Expand Down
7 changes: 7 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ class nat:
"""Class to import in order to use nats."""


class array:
"""Class to import in order to use arrays."""

def __class_getitem__(cls, item):
return cls


@guppy.extend_type(builtins, bool_type_def)
class Bool:
@guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]))
Expand Down
40 changes: 28 additions & 12 deletions guppylang/tys/arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from hugr.serialization import tys

from guppylang.error import InternalGuppyError
from guppylang.tys.common import ToHugr, Transformable, Transformer, Visitor
from guppylang.tys.const import Const
from guppylang.tys.const import BoundConstVar, Const, ConstValue, ExistentialConstVar
from guppylang.tys.var import ExistentialVar

if TYPE_CHECKING:
Expand Down Expand Up @@ -62,27 +63,42 @@ def transform(self, transformer: Transformer) -> Argument:

@dataclass(frozen=True)
class ConstArg(ArgumentBase):
"""Argument that can be substituted for a `ConstParameter`.
"""Argument that can be substituted for a `ConstParam`."""

Note that support for this kind is not implemented yet.
"""

# Hugr value to instantiate
const: Const

@property
def unsolved_vars(self) -> set[ExistentialVar]:
"""The existential type variables contained in this argument."""
raise NotImplementedError
"""The existential const variables contained in this argument."""
return self.const.unsolved_vars

def to_hugr(self) -> tys.TypeArg:
"""Computes the Hugr representation of the argument."""
raise NotImplementedError
"""Computes the Hugr representation of this argument."""
from guppylang.tys.ty import NumericType

match self.const:
case ConstValue(value=v, ty=NumericType(kind=NumericType.Kind.Nat)):
assert isinstance(v, int)
return tys.TypeArg(tys.BoundedNatArg(n=v))
case BoundConstVar(idx=idx):
hugr_ty = self.const.ty.to_hugr()
assert isinstance(hugr_ty.root, tys.Opaque)
param = tys.TypeParam(tys.BoundedNatParam(bound=None))
return tys.TypeArg(tys.VariableArg(idx=idx, cached_decl=param))
case ConstValue() | BoundConstVar():
# TODO: Handle other cases besides nats
raise NotImplementedError
case ExistentialConstVar():
raise InternalGuppyError(
"Tried to convert unsolved constant variable to Hugr"
)

def visit(self, visitor: Visitor) -> None:
"""Accepts a visitor on this argument."""
raise NotImplementedError
visitor.visit(self)

def transform(self, transformer: Transformer) -> Argument:
"""Accepts a transformer on this argument."""
raise NotImplementedError
return transformer.transform(self) or ConstArg(
self.const.transform(transformer)
)
29 changes: 27 additions & 2 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, TypeArg
from guppylang.tys.param import TypeParam
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.ty import (
FunctionType,
NoneType,
Expand Down Expand Up @@ -136,6 +136,20 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
return tys.Type(ty)


def _array_to_hugr(args: Sequence[Argument]) -> tys.Type:
# Type checker ensures that we get a two args
[ty_arg, len_arg] = args
assert isinstance(ty_arg, TypeArg)
assert isinstance(len_arg, ConstArg)
ty = tys.Opaque(
extension="prelude",
id="array",
args=[len_arg.to_hugr(), ty_arg.to_hugr()],
bound=ty_arg.ty.hugr_bound,
)
return tys.Type(ty)


callable_type_def = _CallableTypeDef(DefId.fresh(), None)
tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
none_type_def = _NoneTypeDef(DefId.fresh(), None)
Expand Down Expand Up @@ -172,6 +186,17 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
always_linear=False,
to_hugr=_list_to_hugr,
)
array_type_def = OpaqueTypeDef(
id=DefId.fresh(),
name="array",
defined_at=None,
params=[
TypeParam(0, "T", can_be_linear=True),
ConstParam(1, "n", NumericType(NumericType.Kind.Nat)),
],
always_linear=False,
to_hugr=_array_to_hugr,
)


def bool_type() -> OpaqueType:
Expand Down
68 changes: 57 additions & 11 deletions guppylang/tys/const.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

from hugr.serialization import ops
from typing import TYPE_CHECKING, Any, TypeAlias

from guppylang.error import InternalGuppyError
from guppylang.tys.common import Transformable, Transformer, Visitor
from guppylang.tys.var import BoundVar, ExistentialVar

if TYPE_CHECKING:
from guppylang.tys.arg import ConstArg
from guppylang.tys.ty import Type


@dataclass(frozen=True)
class Const(ABC):
class ConstBase(Transformable["Const"], ABC):
"""Abstract base class for constants arguments in the type system.
In principle, we can allow constants of any type representable in the type system.
Expand All @@ -26,32 +26,66 @@ def __post_init__(self) -> None:
if self.ty.unsolved_vars:
raise InternalGuppyError("Attempted to create constant with unsolved type")

@abstractmethod
def cast(self) -> "Const":
"""Casts an implementor of `ConstBase` into a `Const`.
This enforces that all implementors of `ConstBase` can be embedded into the
`Const` union type.
"""

@property
def unsolved_vars(self) -> set[ExistentialVar]:
"""The existential type variables contained in this constant."""
return set()

def visit(self, visitor: Visitor) -> None:
"""Accepts a visitor on this constant."""
visitor.visit(self)

def transform(self, transformer: Transformer, /) -> "Const":
"""Accepts a transformer on this constant."""
return transformer.transform(self) or self.cast()

def to_arg(self) -> "ConstArg":
"""Wraps this constant into a type argument."""
from guppylang.tys.arg import ConstArg

return ConstArg(self.cast())


@dataclass(frozen=True)
class ConstValue(Const):
class ConstValue(ConstBase):
"""A constant value in the type system.
For example, in the type `array[int, 5]` the second argument is a `ConstArg` that
contains a `ConstValue(5)`.
"""

# Hugr encoding of the value
# TODO: We might need a Guppy representation of this...
value: ops.Value
# TODO: We will need a proper Guppy representation of this in the future
value: Any

def cast(self) -> "Const":
"""Casts an implementor of `ConstBase` into a `Const`."""
return self


@dataclass(frozen=True)
class BoundConstVar(BoundVar, Const):
class BoundConstVar(BoundVar, ConstBase):
"""Bound variable referencing a `ConstParam`.
For example, in the function type `forall n: int. array[float, n] -> array[int, n]`,
we represent the int argument to `array` as a `ConstArg` containing a
`BoundConstVar(idx=0)`.
"""

def cast(self) -> "Const":
"""Casts an implementor of `ConstBase` into a `Const`."""
return self


@dataclass(frozen=True)
class ExistentialConstVar(ExistentialVar, Const):
class ExistentialConstVar(ExistentialVar, ConstBase):
"""Existential constant variable.
During type checking we try to solve all existential constant variables and
Expand All @@ -61,3 +95,15 @@ class ExistentialConstVar(ExistentialVar, Const):
@classmethod
def fresh(cls, display_name: str, ty: "Type") -> "ExistentialConstVar":
return ExistentialConstVar(ty, display_name, next(cls._fresh_id))

@property
def unsolved_vars(self) -> set[ExistentialVar]:
"""The existential type variables contained in this constant."""
return {self}

def cast(self) -> "Const":
"""Casts an implementor of `ConstBase` into a `Const`."""
return self


Const: TypeAlias = ConstValue | BoundConstVar | ExistentialConstVar
Loading

0 comments on commit d706735

Please sign in to comment.