From 7aa1d3909542797b6169fd0261f86cfc73c86239 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 29 May 2024 00:35:55 -0400 Subject: [PATCH] Add literal_pow transformation (#19) --- src/OverflowContexts.jl | 14 ++++++++++++++ src/checked.jl | 29 ++++++++++++++++------------- src/macros.jl | 6 ++++++ src/saturating.jl | 23 ++++++++++++++++------- src/unchecked.jl | 3 +++ test/runtests.jl | 6 ++++++ 6 files changed, 61 insertions(+), 20 deletions(-) diff --git a/src/OverflowContexts.jl b/src/OverflowContexts.jl index db0ee7d..21b606a 100644 --- a/src/OverflowContexts.jl +++ b/src/OverflowContexts.jl @@ -3,6 +3,20 @@ module OverflowContexts const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} const UnsignedBitInteger = Union{UInt8, UInt16, UInt32, UInt64, UInt128} +using Base: BitInteger, promote, afoldl, @_inline_meta +import Base: literal_pow +import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs, + checked_div, checked_fld, checked_cld, checked_mod, checked_rem +using Base.Checked: mul_with_overflow + +if VERSION ≥ v"1.11-alpha" + import Base: power_by_squaring + import Base.Checked: checked_pow +else + using Base: throw_domerr_powbysq, to_power_type + using Base.Checked: throw_overflowerr_binaryop +end + include("macros.jl") include("checked.jl") include("unchecked.jl") diff --git a/src/checked.jl b/src/checked.jl index 2433f71..5b28fec 100644 --- a/src/checked.jl +++ b/src/checked.jl @@ -1,16 +1,3 @@ -using Base: BitInteger, promote, afoldl, @_inline_meta -import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs, - checked_div, checked_fld, checked_cld, checked_mod, checked_rem -using Base.Checked: mul_with_overflow - -if VERSION ≥ v"1.11-alpha" - import Base: power_by_squaring - import Base.Checked: checked_pow -else - using Base: throw_domerr_powbysq, to_power_type - using Base.Checked: throw_overflowerr_binaryop -end - # resolve ambiguity when `-` used as symbol checked_negsub(x) = checked_neg(x) checked_negsub(x, y) = checked_sub(x, y) @@ -87,3 +74,19 @@ if VERSION < v"1.11" return y end end + +# adapted from Base intfuncs.jl; negative literal powers promote to floating point +@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{0}) = one(x) +@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{1}) = x +@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{2}) = @checked x * x +@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{3}) = @checked x * x * x +@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{-1}) = literal_pow(^, x, Val(-1)) +@inline literal_pow(::typeof(checked_pow), x::BitInteger, ::Val{-2}) = literal_pow(^, x, Val(-2)) + +@inline function literal_pow(f::typeof(checked_pow), x, ::Val{p}) where {p} + if p < 0 + literal_pow(^, x, Val(p)) + else + f(x, p) + end +end diff --git a/src/macros.jl b/src/macros.jl index 826e382..ac7ab25 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -235,12 +235,18 @@ function replace_op!(expr::Expr, op_map::Dict) Expr(:tuple, expr.args[2:end]...)] end else # arbitrary call + op_orig = op op = get(op_map, op, op) if isexpr(f, :.) f.args[2] = QuoteNode(op) expr.args[1] = f else expr.args[1] = op + if op_orig == :^ && expr.args[3] isa Integer + # literal_pow transformation + pushfirst!(expr.args, :(Base.literal_pow)) + expr.args[4] = :(Val($(expr.args[4]))) + end end end for i in 2:length(expr.args) diff --git a/src/saturating.jl b/src/saturating.jl index deaaf99..7046e3d 100644 --- a/src/saturating.jl +++ b/src/saturating.jl @@ -1,10 +1,3 @@ -import Base: BitInteger -import Base.Checked: mul_with_overflow - -if VERSION ≥ v"1.11-alpha" - using Base: power_by_squaring -end - # resolve ambiguity when `-` used as symbol saturating_negsub(x) = saturating_neg(x) saturating_negsub(x, y) = saturating_sub(x, y) @@ -156,3 +149,19 @@ function saturating_mod(x::T, y::T) where T <: SignedBitInteger end saturating_mod(x::T, y::T) where T <: UnsignedBitInteger = @saturating rem(x, y) + +# adapted from Base intfuncs.jl; negative literal powers promote to floating point +@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{0}) = one(x) +@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{1}) = x +@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{2}) = @saturating x * x +@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{3}) = @saturating x * x * x +@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{-1}) = literal_pow(^, x, Val(-1)) +@inline literal_pow(::typeof(saturating_pow), x::BitInteger, ::Val{-2}) = literal_pow(^, x, Val(-2)) + +@inline function literal_pow(f::typeof(saturating_pow), x, ::Val{p}) where {p} + if p < 0 + literal_pow(^, x, Val(p)) + else + f(x, p) + end +end diff --git a/src/unchecked.jl b/src/unchecked.jl index 26684bd..94057d6 100644 --- a/src/unchecked.jl +++ b/src/unchecked.jl @@ -67,3 +67,6 @@ unchecked_rem(x::T, y::T) where T <: UnsignedBitInteger = unchecked_mod(x::T, y::T) where T <: SignedBitInteger = x - unchecked_fld(x, y) * y unchecked_mod(x::T, y::T) where T <: UnsignedBitInteger = unchecked_rem(x, y) + +# adapted from Base intfuncs.jl; negative literal powers promote to floating point +@inline literal_pow(::typeof(unchecked_pow), x, ::Val{p}) where {p} = literal_pow(^, x, Val(p)) diff --git a/test/runtests.jl b/test/runtests.jl index 407b688..32c811e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -792,3 +792,9 @@ end @test_throws ErrorException @saturating aa * bb' @test_throws ErrorException @saturating dd ^ 2 end + +@testset "literal_pow transformation" begin + expr = @macroexpand @checked 5 ^ 2 + @test expr.args[1] == :(Base.literal_pow) + @test expr.args[2] == :(OverflowContexts.checked_pow) +end