From 99217dcddb16fa7c713b7e5c5d356715a0fc9496 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 24 Jun 2024 17:03:14 +0100 Subject: [PATCH] feat: Turn int and float into core types (#225) This PR adds a new `NumericType` to the type hierarchy that now represents `int` and `float`. --- guppylang/checker/core.py | 13 +++++++ guppylang/prelude/_internal.py | 63 ++++++++----------------------- guppylang/prelude/builtins.py | 16 +++++--- guppylang/tys/builtin.py | 32 +++++++++++++++- guppylang/tys/printing.py | 5 +++ guppylang/tys/ty.py | 68 +++++++++++++++++++++++++++++++++- 6 files changed, 141 insertions(+), 56 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index cba8930f..72fef4db 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -14,6 +14,8 @@ from guppylang.tys.builtin import ( bool_type_def, callable_type_def, + float_type_def, + int_type_def, linst_type_def, list_type_def, none_type_def, @@ -24,6 +26,7 @@ ExistentialTypeVar, FunctionType, NoneType, + NumericType, OpaqueType, StructType, SumType, @@ -67,6 +70,8 @@ def default() -> "Globals": tuple_type_def, none_type_def, bool_type_def, + int_type_def, + float_type_def, list_type_def, linst_type_def, ] @@ -85,6 +90,14 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None pass case BoundTypeVar() | ExistentialTypeVar() | SumType(): return None + case NumericType(kind): + match kind: + case NumericType.Kind.Int: + type_defn = int_type_def + case NumericType.Kind.Float: + type_defn = float_type_def + case kind: + return assert_never(kind) case FunctionType(): type_defn = callable_type_def case OpaqueType() as ty: diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index d1f806eb..6d02f78f 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -3,7 +3,7 @@ from hugr.serialization import ops, tys from pydantic import BaseModel -from guppylang.ast_util import AstNode, get_type, with_loc, with_type +from guppylang.ast_util import AstNode, get_type, with_loc from guppylang.checker.core import Context from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args from guppylang.definition.custom import ( @@ -12,36 +12,13 @@ CustomFunctionDef, DefaultCallChecker, ) -from guppylang.definition.ty import TypeDef from guppylang.definition.value import CallableDef from guppylang.error import GuppyError, GuppyTypeError from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall from guppylang.tys.builtin import bool_type, list_type from guppylang.tys.subst import Subst -from guppylang.tys.ty import FunctionType, OpaqueType, Type, unify - -INT_WIDTH = 6 # 2^6 = 64 bit - - -hugr_int_type = tys.Type( - tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))], - bound=tys.TypeBound.Eq, - ) -) - - -hugr_float_type = tys.Type( - tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=tys.TypeBound.Copyable, - ) -) +from guppylang.tys.ty import FunctionType, NumericType, Type, unify class ConstInt(BaseModel): @@ -76,9 +53,9 @@ def int_value(i: int) -> ops.Value: return ops.Value( ops.ExtensionValue( extensions=["arithmetic.int.types"], - typ=hugr_int_type, + typ=NumericType(NumericType.Kind.Int).to_hugr(), value=ops.CustomConst( - c="ConstInt", v=ConstInt(log_width=INT_WIDTH, value=i) + c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i) ), ) ) @@ -89,7 +66,7 @@ def float_value(f: float) -> ops.Value: return ops.Value( ops.ExtensionValue( extensions=["arithmetic.float.types"], - typ=hugr_float_type, + typ=NumericType(NumericType.Kind.Float).to_hugr(), value=ops.CustomConst(c="ConstF64", v=ConstF64(value=f)), ) ) @@ -116,16 +93,16 @@ def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType: def int_op( - op_name: str, ext: str = "arithmetic.int", num_params: int = 1 + op_name: str, + ext: str = "arithmetic.int", + args: list[tys.TypeArg] | None = None, + num_params: int = 1, ) -> ops.OpType: """Utility method to create Hugr integer arithmetic ops.""" + if args is None: + args = num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))] return ops.OpType( - ops.CustomOp( - extension=ext, - op_name=op_name, - args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))], - parent=UNDEFINED, - ) + ops.CustomOp(extension=ext, op_name=op_name, args=args, parent=UNDEFINED) ) @@ -140,20 +117,12 @@ class CoercingChecker(DefaultCallChecker): """Function call type checker that automatically coerces arguments to float.""" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - from .builtins import Int - for i in range(len(args)): args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) - if isinstance(ty, OpaqueType) and ty.defn == self.ctx.globals["int"]: - call = with_loc( - self.node, - GlobalCall(def_id=Int.__float__.id, args=[args[i]], type_args=[]), - ) - float_defn = self.ctx.globals["float"] - assert isinstance(float_defn, TypeDef) - args[i] = with_type( - float_defn.check_instantiate([], self.ctx.globals), call - ) + if isinstance(ty, NumericType) and ty.kind != NumericType.Kind.Float: + to_float = self.ctx.globals.get_instance_func(ty, "__float__") + assert to_float is not None + args[i], _ = to_float.synthesize_call([args[i]], self.node, self.ctx) return super().synthesize(args) diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index b013c0d9..03af4d75 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -24,12 +24,16 @@ ReversingChecker, UnsupportedChecker, float_op, - hugr_float_type, - hugr_int_type, int_op, logic_op, ) -from guppylang.tys.builtin import bool_type_def, linst_type_def, list_type_def +from guppylang.tys.builtin import ( + bool_type_def, + float_type_def, + int_type_def, + linst_type_def, + list_type_def, +) builtins = GuppyModule("builtins", import_builtins=False) @@ -64,7 +68,7 @@ def __new__(x): ... def __or__(self: bool, other: bool) -> bool: ... -@guppy.type(builtins, hugr_int_type, name="int") +@guppy.extend_type(builtins, int_type_def) class Int: @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) def __abs__(self: int) -> int: ... @@ -106,7 +110,7 @@ def __gt__(self: int, other: int) -> bool: ... def __int__(self: int) -> int: ... @guppy.hugr_op(builtins, int_op("inot")) - def __invert__(self: int) -> bool: ... + def __invert__(self: int) -> int: ... @guppy.hugr_op(builtins, int_op("ile_s")) def __le__(self: int, other: int) -> bool: ... @@ -203,7 +207,7 @@ def __trunc__(self: int) -> int: ... def __xor__(self: int, other: int) -> int: ... -@guppy.type(builtins, hugr_float_type, name="float", bound=tys.TypeBound.Copyable) +@guppy.extend_type(builtins, float_type_def) class Float: @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) def __abs__(self: float) -> float: ... diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 2c984496..170707d7 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -10,7 +10,14 @@ from guppylang.error import GuppyError from guppylang.tys.arg import Argument, TypeArg from guppylang.tys.param import TypeParam -from guppylang.tys.ty import FunctionType, NoneType, OpaqueType, TupleType, Type +from guppylang.tys.ty import ( + FunctionType, + NoneType, + NumericType, + OpaqueType, + TupleType, + Type, +) if TYPE_CHECKING: from guppylang.checker.core import Globals @@ -79,6 +86,23 @@ def check_instantiate( return NoneType() +@dataclass(frozen=True) +class _NumericTypeDef(TypeDef): + """Type definition associated with the builtin numeric types. + + Any impls on numerics can be registered with these definitions. + """ + + ty: NumericType + + def check_instantiate( + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + ) -> NumericType: + if args: + raise GuppyError(f"Type `{self.name}` is not parameterized", loc) + return self.ty + + @dataclass(frozen=True) class _ListTypeDef(OpaqueTypeDef): """Type definition associated with the builtin `list` type. @@ -123,6 +147,12 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: always_linear=False, to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))), ) +int_type_def = _NumericTypeDef( + DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int) +) +float_type_def = _NumericTypeDef( + DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float) +) linst_type_def = OpaqueTypeDef( id=DefId.fresh(), name="linst", diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 8ac82be6..2d0efb44 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -6,6 +6,7 @@ from guppylang.tys.ty import ( FunctionType, NoneType, + NumericType, OpaqueType, StructType, SumType, @@ -106,6 +107,10 @@ def _visit_SumType(self, ty: SumType, inside_row: bool) -> str: def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str: return "None" + @_visit.register + def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str: + return ty.kind.value + @_visit.register def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str: # TODO: Print linearity? diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 5b176120..75b3700f 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass, field +from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, TypeAlias, cast +from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast from hugr.serialization import tys from hugr.serialization.tys import TypeBound @@ -234,6 +235,65 @@ def transform(self, transformer: Transformer) -> "Type": return transformer.transform(self) or self +@dataclass(frozen=True) +class NumericType(TypeBase): + """Numeric types like `int` and `float`.""" + + kind: "Kind" + + class Kind(Enum): + """The different kinds of numeric types.""" + + Int = "int" + Float = "float" + + INT_WIDTH: ClassVar[int] = 6 + + @property + def linear(self) -> bool: + """Whether this type should be treated linearly.""" + return False + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + match self.kind: + case NumericType.Kind.Int: + return tys.Type( + tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))], + bound=tys.TypeBound.Eq, + ) + ) + case NumericType.Kind.Float: + return tys.Type( + tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, + ) + ) + + @property + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" + match self.kind: + case NumericType.Kind.Float: + return tys.TypeBound.Copyable + case _: + return tys.TypeBound.Eq + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + visitor.visit(self) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or self + + @dataclass(frozen=True, init=False) class FunctionType(ParametrizedTypeBase): """Type of (potentially generic) functions.""" @@ -493,7 +553,9 @@ def transform(self, transformer: Transformer) -> "Type": #: This might become obsolete in case the @sealed decorator is added: #: * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types #: * https://github.com/johnthagen/sealed-typing-pep -Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType +Type: TypeAlias = ( + BoundTypeVar | ExistentialTypeVar | NumericType | NoneType | ParametrizedType +) #: An immutable row of Guppy types. TypeRow: TypeAlias = Sequence[Type] @@ -545,6 +607,8 @@ def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None": return _unify_var(t, s, subst) case BoundTypeVar(idx=s_idx), BoundTypeVar(idx=t_idx) if s_idx == t_idx: return subst + case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind: + return subst case NoneType(), NoneType(): return subst case FunctionType() as s, FunctionType() as t if s.params == t.params: