From eca9dc62182c027cc62cd760c98bd7f973067ec2 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 10 Dec 2024 12:20:10 +0000 Subject: [PATCH 1/2] feat: Implicit coercion of numeric types --- guppylang/checker/expr_checker.py | 36 +++++++++++++++++++++++----- guppylang/definition/custom.py | 2 +- guppylang/std/_internal/checker.py | 7 ++++-- guppylang/tys/printing.py | 2 +- guppylang/tys/ty.py | 12 ++++++---- tests/integration/test_arithmetic.py | 11 +++++++++ 6 files changed, 56 insertions(+), 14 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index ca5f03be..61e24674 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -121,6 +121,7 @@ FunctionType, InputFlags, NoneType, + NumericType, OpaqueType, StructType, TupleType, @@ -207,7 +208,7 @@ def check( # If we already have a type for the expression, we just have to match it against # the target if actual := get_type_opt(expr): - subst, inst = check_type_against(actual, ty, expr, kind) + expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind) if inst: expr = with_loc(expr, TypeApply(value=expr, tys=inst)) return with_type(ty.substitute(subst), expr), subst @@ -329,7 +330,7 @@ def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]: def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]: # Try to synthesize and then check if we can unify it with the given type node, synth = self._synthesize(node, allow_free_vars=False) - subst, inst = check_type_against(synth, ty, node, self._kind) + node, subst, inst = check_type_against(synth, ty, node, self.ctx, self._kind) # Apply instantiation of quantified type variables if inst: @@ -759,8 +760,8 @@ def generic_visit(self, node: ast.expr) -> NoReturn: def check_type_against( - act: Type, exp: Type, node: AstNode, kind: str = "expression" -) -> tuple[Subst, Inst]: + act: Type, exp: Type, node: ast.expr, ctx: Context, kind: str = "expression" +) -> tuple[ast.expr, Subst, Inst]: """Checks a type against another type. Returns a substitution for the free variables the expected type and an instantiation @@ -797,14 +798,37 @@ def check_type_against( # Finally, check that the instantiation respects the linearity requirements check_inst(act, inst, node) - return subst, inst + return node, subst, inst # Otherwise, we know that `act` has no unsolved type vars, so unification is trivial assert not act.unsolved_vars subst = unify(exp, act, {}) if subst is None: + # Maybe we can implicitly coerce `act` to `exp` + if coerced := try_coerce_to(act, exp, node, ctx): + return coerced, {}, [] raise GuppyTypeError(TypeMismatchError(node, exp, act, kind)) - return subst, [] + return node, subst, [] + + +def try_coerce_to( + act: Type, exp: Type, node: ast.expr, ctx: Context +) -> ast.expr | None: + """Tries to implicitly coerce an expression to a different type. + + Returns the coerced expression or `None` if the type cannot be implicitly coerced. + """ + # Currently, we only support implicit coercions of numeric types + if not isinstance(act, NumericType) or not isinstance(exp, NumericType): + return None + # Ordering on `NumericType.Kind` defines the coercion relation + if act.kind < exp.kind: + f = ctx.globals.get_instance_func(act, f"__{exp.kind.name.lower()}__") + assert f is not None + node, subst = f.check_call([node], exp, node, ctx) + assert len(subst) == 0, "Coercion methods are not generic" + return node + return None def check_num_args( diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index 1279d26c..c456b832 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -308,7 +308,7 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: from guppylang.checker.expr_checker import check_type_against expr, res_ty = self.synthesize(args) - subst, _ = check_type_against(res_ty, ty, self.node) + expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx) return expr, subst @abstractmethod diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index 3487585e..6aac4e71 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -232,7 +232,10 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: # TODO: We could use the type information to infer some stuff # in the comprehension arr_compr, res_ty = self.synthesize_array_comprehension(compr) - subst, _ = check_type_against(res_ty, ty, self.node) + arr_compr = with_loc(self.node, arr_compr) + arr_compr, subst, _ = check_type_against( + res_ty, ty, arr_compr, self.ctx + ) return arr_compr, subst # Or a list of array elements case args: @@ -359,7 +362,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: expr, res_ty = self.synthesize(args) - subst, _ = check_type_against(res_ty, ty, self.node) + expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx) return expr, subst @staticmethod diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 1e2cde16..633039fd 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -125,7 +125,7 @@ def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str: @_visit.register def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str: - return ty.kind.value + return ty.kind.name.lower() @_visit.register def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str: diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 41f91bcc..1daad256 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field from enum import Enum, Flag, auto -from functools import cached_property +from functools import cached_property, total_ordering from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast import hugr.std.float @@ -259,12 +259,16 @@ class NumericType(TypeBase): kind: "Kind" + @total_ordering class Kind(Enum): """The different kinds of numeric types.""" - Nat = "nat" - Int = "int" - Float = "float" + Nat = auto() + Int = auto() + Float = auto() + + def __lt__(self, other: "NumericType.Kind") -> bool: + return self.value < other.value INT_WIDTH: ClassVar[int] = 6 diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 4a10312a..33d03914 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -117,6 +117,17 @@ def main(a1: angle, a2: angle) -> bool: validate(module.compile()) +def test_implicit_coercion(validate): + @compile_guppy + def coerce(x: nat) -> float: + y: int = x + z: float = y + a: float = 1 + return z + a + + validate(coerce) + + def test_angle_float_coercion(validate): module = GuppyModule("test") module.load(angle) From 49b609592038eb172646cfb827f66df9065105de Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 10 Dec 2024 13:11:19 +0000 Subject: [PATCH 2/2] chore: Define arithmetic methods via Guppy source instead of custom compilers --- guppylang/std/_internal/checker.py | 52 ++--- .../std/_internal/compiler/arithmetic.py | 218 +----------------- guppylang/std/builtins.py | 191 +++++++-------- 3 files changed, 115 insertions(+), 346 deletions(-) diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index 6aac4e71..4166cded 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -4,7 +4,7 @@ from typing_extensions import assert_never -from guppylang.ast_util import AstNode, with_loc, with_type +from guppylang.ast_util import with_loc, with_type from guppylang.checker.core import Context from guppylang.checker.errors.generic import ExpectedError, UnsupportedError from guppylang.checker.errors.type_errors import ( @@ -22,8 +22,6 @@ ) from guppylang.definition.custom import ( CustomCallChecker, - CustomFunctionDef, - DefaultCallChecker, ) from guppylang.definition.struct import CheckedStructDef, RawStructDef from guppylang.diagnostic import Error, Note @@ -60,42 +58,28 @@ ) -class CoercingChecker(DefaultCallChecker): - """Function call type checker that automatically coerces arguments to float.""" - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - for i in range(len(args)): - args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) - if isinstance(ty, NumericType) and ty.kind != NumericType.Kind.Float: - to_float = self.ctx.globals.get_instance_func(ty, "__float__") - assert to_float is not None - args[i], _ = to_float.synthesize_call([args[i]], self.node, self.ctx) - return super().synthesize(args) - - class ReversingChecker(CustomCallChecker): - """Call checker that reverses the arguments after checking.""" - - base_checker: CustomCallChecker + """Call checker for reverse arithmetic methods. - def __init__(self, base_checker: CustomCallChecker | None = None): - self.base_checker = base_checker or DefaultCallChecker() - - def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None: - super()._setup(ctx, node, func) - self.base_checker._setup(ctx, node, func) + For examples, turns a call to `__radd__` into a call to `__add__` with reversed + arguments. + """ - def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: - expr, subst = self.base_checker.check(args, ty) - if isinstance(expr, GlobalCall): - expr.args = list(reversed(expr.args)) - return expr, subst + def parse_name(self) -> str: + # Must be a dunder method + assert self.func.name.startswith("__") + assert self.func.name.endswith("__") + name = self.func.name[2:-2] + # Remove the `r` + assert name.startswith("r") + return f"__{name[1:]}__" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - expr, ty = self.base_checker.synthesize(args) - if isinstance(expr, GlobalCall): - expr.args = list(reversed(expr.args)) - return expr, ty + [self_arg, other_arg] = args + self_arg, self_ty = ExprSynthesizer(self.ctx).synthesize(self_arg) + f = self.ctx.globals.get_instance_func(self_ty, self.parse_name()) + assert f is not None + return f.synthesize_call([other_arg, self_arg], self.node, self.ctx) class UnsupportedChecker(CustomCallChecker): diff --git a/guppylang/std/_internal/compiler/arithmetic.py b/guppylang/std/_internal/compiler/arithmetic.py index f1202e97..18d06bd5 100644 --- a/guppylang/std/_internal/compiler/arithmetic.py +++ b/guppylang/std/_internal/compiler/arithmetic.py @@ -3,14 +3,10 @@ from collections.abc import Sequence import hugr.std.int -from hugr import Wire, ops +from hugr import ops from hugr import tys as ht -from hugr.std.float import FLOAT_T from hugr.std.int import int_t -from guppylang.definition.custom import ( - CustomCallCompiler, -) from guppylang.tys.ty import NumericType INT_T = int_t(NumericType.INT_WIDTH) @@ -92,215 +88,3 @@ def convert_ifrombool() -> ops.ExtOp: def convert_itobool() -> ops.ExtOp: """Returns a `std.arithmetic.conversions.itobool` operation.""" return _instantiate_convert_op("itobool", [int_t(0)], [ht.Bool]) - - -# ------------------------------------------------------ -# --------- Custom compilers for non-native ops -------- -# ------------------------------------------------------ - - -class NatTruedivCompiler(CustomCallCompiler): - """Compiler for the `nat.__truediv__` method.""" - - def compile(self, args: list[Wire]) -> list[Wire]: - from guppylang.std.builtins import Float, Nat - - # Compile `truediv` using float arithmetic - [left, right] = args - [left] = Nat.__float__.compile_call( - [left], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([INT_T], [FLOAT_T]), - ) - [right] = Nat.__float__.compile_call( - [right], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([INT_T], [FLOAT_T]), - ) - [out] = Float.__truediv__.compile_call( - [left, right], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), - ) - return [out] - - -class IntTruedivCompiler(CustomCallCompiler): - """Compiler for the `int.__truediv__` method.""" - - def compile(self, args: list[Wire]) -> list[Wire]: - from guppylang.std.builtins import Float, Int - - # Compile `truediv` using float arithmetic - [left, right] = args - [left] = Int.__float__.compile_call( - [left], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([INT_T], [FLOAT_T]), - ) - [right] = Int.__float__.compile_call( - [right], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([INT_T], [FLOAT_T]), - ) - [out] = Float.__truediv__.compile_call( - [left, right], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), - ) - return [out] - - -class FloatBoolCompiler(CustomCallCompiler): - """Compiler for the `float.__bool__` method.""" - - def compile(self, args: list[Wire]) -> list[Wire]: - from guppylang.std.builtins import Float - - # We have: bool(x) = (x != 0.0) - zero = self.builder.load(hugr.std.float.FloatVal(0.0)) - [out] = Float.__ne__.compile_call( - [args[0], zero], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [ht.Bool]), - ) - return [out] - - -class FloatFloordivCompiler(CustomCallCompiler): - """Compiler for the `float.__floordiv__` method.""" - - def compile(self, args: list[Wire]) -> list[Wire]: - from guppylang.std.builtins import Float - - # We have: floordiv(x, y) = floor(truediv(x, y)) - [div] = Float.__truediv__.compile_call( - args, - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), - ) - [floor] = Float.__floor__.compile_call( - [div], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T], [FLOAT_T]), - ) - return [floor] - - -class FloatModCompiler(CustomCallCompiler): - """Compiler for the `float.__mod__` method.""" - - def compile(self, args: list[Wire]) -> list[Wire]: - from guppylang.std.builtins import Float - - # We have: mod(x, y) = x - (x // y) * y - [div] = Float.__floordiv__.compile_call( - args, - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]), - ) - [mul] = Float.__mul__.compile_call( - [div, args[1]], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), - ) - [sub] = Float.__sub__.compile_call( - [args[0], mul], - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), - ) - return [sub] - - -class FloatDivmodCompiler(CustomCallCompiler): - """Compiler for the `__divmod__` method.""" - - def compile(self, args: list[Wire]) -> list[Wire]: - from guppylang.std.builtins import Float - - # We have: divmod(x, y) = (div(x, y), mod(x, y)) - [div] = Float.__truediv__.compile_call( - args, - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T, FLOAT_T], [FLOAT_T]), - ) - [mod] = Float.__mod__.compile_call( - args, - [], - self.dfg, - self.globals, - self.node, - ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]), - ) - return list(self.builder.add(ops.MakeTuple()(div, mod))) - - -class IToBoolCompiler(CustomCallCompiler): - """Compiler for the `Int` and `Nat` `.__bool__` methods. - - Note that the native `std.arithmetic.conversions.itobool` hugr op - only supports 1 bit integers as input. - """ - - def compile(self, args: list[Wire]) -> list[Wire]: - # Emit a comparison against zero - [num] = args - zero = self.builder.load(hugr.std.int.IntVal(0, width=6)) - out = self.builder.add_op(ine(NumericType.INT_WIDTH), num, zero) - return [out] - - -class IFromBoolCompiler(CustomCallCompiler): - """Compiler for the `Bool` `.__int__` and `.__nat__` methods. - - Note that the native `std.arithmetic.conversions.ifrombool` hugr op - only produces 1 bit integers as output, so we have to widen the result. - """ - - def compile(self, args: list[Wire]) -> list[Wire]: - # Emit an `ifrombool` followed by a widening cast - # We use `widen_u` independently of the target type, since we want the bit `1` - # to be expanded to `0x00000001` even for `nat` types - [boolean] = args - bit = self.builder.add_op(convert_ifrombool(), boolean) - num = self.builder.add_op(iwiden_u(0, NumericType.INT_WIDTH), bit) - return [num] diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index 46a09ee8..a4027cd3 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -7,11 +7,10 @@ import hugr.std.int from guppylang.decorator import guppy -from guppylang.definition.custom import DefaultCallChecker, NoopCompiler +from guppylang.definition.custom import NoopCompiler from guppylang.std._internal.checker import ( ArrayLenChecker, CallableChecker, - CoercingChecker, DunderChecker, NewArrayChecker, RangeChecker, @@ -19,16 +18,6 @@ ReversingChecker, UnsupportedChecker, ) -from guppylang.std._internal.compiler.arithmetic import ( - FloatBoolCompiler, - FloatDivmodCompiler, - FloatFloordivCompiler, - FloatModCompiler, - IFromBoolCompiler, - IntTruedivCompiler, - IToBoolCompiler, - NatTruedivCompiler, -) from guppylang.std._internal.compiler.array import ( ArrayGetitemCompiler, ArrayIterEndCompiler, @@ -109,11 +98,16 @@ def __bool__(self: bool) -> bool: ... @guppy.hugr_op(logic_op("Eq")) def __eq__(self: bool, other: bool) -> bool: ... - @guppy.custom(IFromBoolCompiler()) - def __int__(self: bool) -> int: ... + @guppy + @no_type_check + def __int__(self: bool) -> int: + return 1 if self else 0 - @guppy.custom(IFromBoolCompiler()) - def __nat__(self: bool) -> nat: ... + @guppy + @no_type_check + def __nat__(self: bool) -> nat: + # TODO: Literals should check against `nat` + return nat(1) if self else nat(0) @guppy.custom(checker=DunderChecker("__bool__"), higher_order_value=False) def __new__(x): ... @@ -136,8 +130,10 @@ def __add__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(int_op("iand")) def __and__(self: nat, other: nat) -> nat: ... - @guppy.custom(IToBoolCompiler()) - def __bool__(self: nat) -> bool: ... + @guppy + @no_type_check + def __bool__(self: nat) -> bool: + return self != 0 @guppy.custom(NoopCompiler()) def __ceil__(self: nat) -> nat: ... @@ -202,56 +198,58 @@ def __pos__(self: nat) -> nat: ... @guppy.hugr_op(int_op("ipow")) def __pow__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("iadd"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __radd__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("iand"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rand__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("idivmod_u"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rdivmod__(self: nat, other: nat) -> tuple[nat, nat]: ... - @guppy.hugr_op(int_op("idiv_u"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rfloordiv__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("ishl"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rlshift__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("imod_u"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rmod__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("imul"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rmul__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("ior"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __ror__(self: nat, other: nat) -> nat: ... @guppy.custom(NoopCompiler()) def __round__(self: nat) -> nat: ... - @guppy.hugr_op(int_op("ipow"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rpow__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("ishr"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rrshift__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(int_op("ishr")) def __rshift__(self: nat, other: nat) -> nat: ... - @guppy.hugr_op(int_op("isub"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rsub__(self: nat, other: nat) -> nat: ... - @guppy.custom(NatTruedivCompiler(), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rtruediv__(self: nat, other: nat) -> float: ... - @guppy.hugr_op(int_op("ixor"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rxor__(self: nat, other: nat) -> nat: ... @guppy.hugr_op(int_op("isub")) def __sub__(self: nat, other: nat) -> nat: ... - @guppy.custom(NatTruedivCompiler()) - def __truediv__(self: nat, other: nat) -> float: ... + @guppy + @no_type_check + def __truediv__(self: nat, other: nat) -> float: + return float(self) / float(other) @guppy.custom(NoopCompiler()) def __trunc__(self: nat) -> nat: ... @@ -271,8 +269,10 @@ def __add__(self: int, other: int) -> int: ... @guppy.hugr_op(int_op("iand")) def __and__(self: int, other: int) -> int: ... - @guppy.custom(IToBoolCompiler()) - def __bool__(self: int) -> bool: ... + @guppy + @no_type_check + def __bool__(self: int) -> bool: + return self != 0 @guppy.custom(NoopCompiler()) def __ceil__(self: int) -> int: ... @@ -340,56 +340,58 @@ def __pos__(self: int) -> int: ... @guppy.hugr_op(int_op("ipow")) def __pow__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("iadd"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __radd__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("iand"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rand__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("idivmod_s"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rdivmod__(self: int, other: int) -> tuple[int, int]: ... - @guppy.hugr_op(int_op("idiv_s"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rfloordiv__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("ishl"), ReversingChecker()) # TODO: RHS is unsigned + @guppy.custom(checker=ReversingChecker()) # TODO: RHS is unsigned def __rlshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("imod_s"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rmod__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("imul"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rmul__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("ior"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __ror__(self: int, other: int) -> int: ... @guppy.custom(NoopCompiler()) def __round__(self: int) -> int: ... - @guppy.hugr_op(int_op("ipow"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rpow__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("ishr"), ReversingChecker()) # TODO: RHS is unsigned + @guppy.custom(checker=ReversingChecker()) # TODO: RHS is unsigned def __rrshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("ishr")) # TODO: RHS is unsigned + @guppy.custom(checker=ReversingChecker()) # TODO: RHS is unsigned def __rshift__(self: int, other: int) -> int: ... - @guppy.hugr_op(int_op("isub"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rsub__(self: int, other: int) -> int: ... - @guppy.custom(IntTruedivCompiler(), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rtruediv__(self: int, other: int) -> float: ... - @guppy.hugr_op(int_op("ixor"), ReversingChecker()) + @guppy.custom(checker=ReversingChecker()) def __rxor__(self: int, other: int) -> int: ... @guppy.hugr_op(int_op("isub")) def __sub__(self: int, other: int) -> int: ... - @guppy.custom(IntTruedivCompiler()) - def __truediv__(self: int, other: int) -> float: ... + @guppy + @no_type_check + def __truediv__(self: int, other: int) -> float: + return float(self) / float(other) @guppy.custom(NoopCompiler()) def __trunc__(self: int) -> int: ... @@ -400,115 +402,114 @@ def __xor__(self: int, other: int) -> int: ... @guppy.extend_type(float_type_def) class Float: - @guppy.hugr_op(float_op("fabs"), CoercingChecker()) + @guppy.hugr_op(float_op("fabs")) def __abs__(self: float) -> float: ... - @guppy.hugr_op(float_op("fadd"), CoercingChecker()) + @guppy.hugr_op(float_op("fadd")) def __add__(self: float, other: float) -> float: ... - @guppy.custom(FloatBoolCompiler(), CoercingChecker()) - def __bool__(self: float) -> bool: ... + @guppy + @no_type_check + def __bool__(self: float) -> bool: + return self != 0.0 - @guppy.hugr_op(float_op("fceil"), CoercingChecker()) + @guppy.hugr_op(float_op("fceil")) def __ceil__(self: float) -> float: ... - @guppy.custom(FloatDivmodCompiler(), CoercingChecker()) - def __divmod__(self: float, other: float) -> tuple[float, float]: ... + @guppy + @no_type_check + def __divmod__(self: float, other: float) -> tuple[float, float]: + return self // other, self.__mod__(other) - @guppy.hugr_op(float_op("feq"), CoercingChecker()) + @guppy.hugr_op(float_op("feq")) def __eq__(self: float, other: float) -> bool: ... - @guppy.custom(NoopCompiler(), CoercingChecker()) + @guppy.custom(NoopCompiler()) def __float__(self: float) -> float: ... - @guppy.hugr_op(float_op("ffloor"), CoercingChecker()) + @guppy.hugr_op(float_op("ffloor")) def __floor__(self: float) -> float: ... - @guppy.custom(FloatFloordivCompiler(), CoercingChecker()) - def __floordiv__(self: float, other: float) -> float: ... + @guppy + @no_type_check + def __floordiv__(self: float, other: float) -> float: + return (self / other).__floor__() - @guppy.hugr_op(float_op("fge"), CoercingChecker()) + @guppy.hugr_op(float_op("fge")) def __ge__(self: float, other: float) -> bool: ... - @guppy.hugr_op(float_op("fgt"), CoercingChecker()) + @guppy.hugr_op(float_op("fgt")) def __gt__(self: float, other: float) -> bool: ... - @guppy.hugr_op( - unsupported_op("trunc_s"), CoercingChecker() - ) # TODO `trunc_s` returns an option + @guppy.hugr_op(unsupported_op("trunc_s")) # TODO `trunc_s` returns an option def __int__(self: float) -> int: ... - @guppy.hugr_op(float_op("fle"), CoercingChecker()) + @guppy.hugr_op(float_op("fle")) def __le__(self: float, other: float) -> bool: ... - @guppy.hugr_op(float_op("flt"), CoercingChecker()) + @guppy.hugr_op(float_op("flt")) def __lt__(self: float, other: float) -> bool: ... - @guppy.custom(FloatModCompiler(), CoercingChecker()) - def __mod__(self: float, other: float) -> float: ... + @guppy + @no_type_check + def __mod__(self: float, other: float) -> float: + return self - (self // other) * other - @guppy.hugr_op(float_op("fmul"), CoercingChecker()) + @guppy.hugr_op(float_op("fmul")) def __mul__(self: float, other: float) -> float: ... - @guppy.hugr_op( - unsupported_op("trunc_u"), CoercingChecker() - ) # TODO `trunc_u` returns an option + @guppy.hugr_op(unsupported_op("trunc_u")) # TODO `trunc_u` returns an option def __nat__(self: float) -> nat: ... - @guppy.hugr_op(float_op("fne"), CoercingChecker()) + @guppy.hugr_op(float_op("fne")) def __ne__(self: float, other: float) -> bool: ... - @guppy.hugr_op(float_op("fneg"), CoercingChecker()) + @guppy.hugr_op(float_op("fneg")) def __neg__(self: float) -> float: ... @guppy.custom(checker=DunderChecker("__float__"), higher_order_value=False) def __new__(x): ... - @guppy.custom(NoopCompiler(), CoercingChecker()) + @guppy.custom(NoopCompiler()) def __pos__(self: float) -> float: ... @guppy.hugr_op(float_op("fpow")) # TODO def __pow__(self: float, other: float) -> float: ... - @guppy.hugr_op(float_op("fadd"), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __radd__(self: float, other: float) -> float: ... - @guppy.custom(FloatDivmodCompiler(), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __rdivmod__(self: float, other: float) -> tuple[float, float]: ... - @guppy.custom(FloatFloordivCompiler(), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __rfloordiv__(self: float, other: float) -> float: ... - @guppy.custom(FloatModCompiler(), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __rmod__(self: float, other: float) -> float: ... - @guppy.hugr_op(float_op("fmul"), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __rmul__(self: float, other: float) -> float: ... @guppy.hugr_op(float_op("fround")) # TODO def __round__(self: float) -> float: ... - @guppy.hugr_op( - float_op("fpow"), - ReversingChecker(DefaultCallChecker()), - ) # TODO + @guppy.custom(checker=ReversingChecker()) def __rpow__(self: float, other: float) -> float: ... - @guppy.hugr_op(float_op("fsub"), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __rsub__(self: float, other: float) -> float: ... - @guppy.hugr_op(float_op("fdiv"), ReversingChecker(CoercingChecker())) + @guppy.custom(checker=ReversingChecker()) def __rtruediv__(self: float, other: float) -> float: ... - @guppy.hugr_op(float_op("fsub"), CoercingChecker()) + @guppy.hugr_op(float_op("fsub")) def __sub__(self: float, other: float) -> float: ... - @guppy.hugr_op(float_op("fdiv"), CoercingChecker()) + @guppy.hugr_op(float_op("fdiv")) def __truediv__(self: float, other: float) -> float: ... - @guppy.hugr_op( - unsupported_op("trunc_s"), CoercingChecker() - ) # TODO `trunc_s` returns an option + @guppy.hugr_op(unsupported_op("trunc_s")) # TODO `trunc_s` returns an option def __trunc__(self: float) -> float: ...