From 282548520d232c2f78af6099b809c5d3df3c1122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 10 Jan 2025 16:01:56 +0100 Subject: [PATCH 1/7] Split `should_rewrite_ft` for `call` and `invoke` expressions --- src/utils.jl | 25 +++++++++++++++++++++---- test/compile.jl | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d1130ed90..b9d6e9496 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -89,7 +89,7 @@ function has_ancestor(query::Module, target::Module) end end -function should_rewrite_ft(@nospecialize(ft)) +function should_rewrite_call(@nospecialize(ft)) # Don't rewrite builtin or intrinsics if ft <: Core.IntrinsicFunction || ft <: Core.Builtin return false @@ -171,6 +171,20 @@ function should_rewrite_ft(@nospecialize(ft)) return true end +# by default, same as `should_rewrite_call` +function should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) + Core.println(Core.stderr, "[should_rewrite_invoke] ft = $ft, parameters = $(args)") + return should_rewrite_call(ft) +end + +function should_rewrite_invoke(::typeof(Base.unique), @nospecialize(args)) + Core.println(Core.stderr, "[should_rewrite_invoke] Base.unique catched!") + if args === Tuple{Vector{Symbol}} + return false + end + return true +end + # Avoid recursively interpreting into methods we define explicitly # as overloads, which we assume should handle the entirety of the # translation (and if not they can use call_in_reactant). @@ -231,7 +245,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) end if ft == typeof(Core._apply_iterate) ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) - if Base.invokelatest(should_rewrite_ft, ft) + if Base.invokelatest(should_rewrite_call, ft) if RT === Union{} rep = Expr( :call, @@ -245,7 +259,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) return true, rep, Any end end - elseif Base.invokelatest(should_rewrite_ft, ft) + elseif Base.invokelatest(should_rewrite_call, ft) if RT === Union{} rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...) return true, rep, Union{} @@ -259,10 +273,13 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) omi = inst.args[1]::Core.MethodInstance sig = omi.specTypes ft = sig.parameters[1] + argsig = sig.parameters[2:end] if ft == typeof(Core.kwcall) ft = sig.parameters[3] + argsig = sig.parameters[4:end] end - if Base.invokelatest(should_rewrite_ft, ft) && !is_reactant_method(omi) + argsig = Core.apply_type(Core.Tuple, argsig...) + if Base.invokelatest(should_rewrite_invoke, ft, argsig) && !is_reactant_method(omi) method = omi.def::Core.Method min_world = Ref{UInt}(typemin(UInt)) diff --git a/test/compile.jl b/test/compile.jl index 1f780a4ef..680b59243 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -128,7 +128,7 @@ end @test !occursin("add", repr(hlo)) end -# While a bit specific, the following is used to check for a bug in `should_rewrite_ft` +# While a bit specific, the following is used to check for a bug in `should_rewrite_call` function sinusoidal_embedding( x::AbstractArray{T,4}, min_freq, max_freq, embedding_dims::Int ) where {T} From e90e79fbb6dc7d6b25fb6bc655eef2009dfb8caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 10 Jan 2025 17:06:51 +0100 Subject: [PATCH 2/7] fix dispatch on `unique` --- src/utils.jl | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b9d6e9496..4af84fa3a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -172,18 +172,11 @@ function should_rewrite_call(@nospecialize(ft)) end # by default, same as `should_rewrite_call` -function should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) - Core.println(Core.stderr, "[should_rewrite_invoke] ft = $ft, parameters = $(args)") - return should_rewrite_call(ft) -end +should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_call(ft) -function should_rewrite_invoke(::typeof(Base.unique), @nospecialize(args)) - Core.println(Core.stderr, "[should_rewrite_invoke] Base.unique catched!") - if args === Tuple{Vector{Symbol}} - return false - end - return true -end +# fixes #493 +# TODO we probably want to skip rewrite if args do not contain Reactant types +should_rewrite_invoke(::Type{typeof(Base.unique)}, @nospecialize(::Type{Tuple{Vector{Symbol}}})) = false # Avoid recursively interpreting into methods we define explicitly # as overloads, which we assume should handle the entirety of the From 96745de0b953ce302700d56b3b840e60bccf6544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 10 Jan 2025 17:32:39 +0100 Subject: [PATCH 3/7] fix syntax problem --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 4af84fa3a..7692c4e2f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -176,7 +176,7 @@ should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_c # fixes #493 # TODO we probably want to skip rewrite if args do not contain Reactant types -should_rewrite_invoke(::Type{typeof(Base.unique)}, @nospecialize(::Type{Tuple{Vector{Symbol}}})) = false +should_rewrite_invoke(::Type{typeof(Base.unique)}, ::Type{Tuple{Vector{Symbol}}}) = false # Avoid recursively interpreting into methods we define explicitly # as overloads, which we assume should handle the entirety of the From 39a9ba061361150e27f5a74fcbfe771dff36d686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 12 Jan 2025 22:25:10 +0100 Subject: [PATCH 4/7] test --- test/compile.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/compile.jl b/test/compile.jl index 680b59243..9fab2e60c 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -146,3 +146,9 @@ end x_ra = Reactant.to_rarray(rand(Float32, 1, 1, 1, 4)) hlo = @code_hlo sinusoidal_embedding(x_ra, 0.1, 10.0, 4) end + +# test #493 +@testset "unique(::Vector{Symbol}) (#493)" begin + x = [:a, :b, :a] + @test @jit(unique(x_ra)) == [:a, :b] +end From 2c103986a55d5774c3dc7a3115648bce52772b82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 12 Jan 2025 23:09:07 +0100 Subject: [PATCH 5/7] fix typo --- test/compile.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compile.jl b/test/compile.jl index 9fab2e60c..c21f5bfca 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -150,5 +150,5 @@ end # test #493 @testset "unique(::Vector{Symbol}) (#493)" begin x = [:a, :b, :a] - @test @jit(unique(x_ra)) == [:a, :b] + @test @jit(unique(x)) == [:a, :b] end From 004722091469f72668ffdea1501533ffeaab583c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Jan 2025 15:12:18 +0100 Subject: [PATCH 6/7] Replace `should_rewrite_invoke` of `unique` for overlayed method on `_unique_dims` --- src/Overlay.jl | 9 +++++++++ src/utils.jl | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 7c41346e9..d597cce61 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -142,3 +142,12 @@ end end end end + +## fixes #493 +@reactant_overlay @noinline function Base._unique_dims(A::AbstractArray, dims::Colon) + if use_overlayed_version(A) + error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.") + else + Base.inferencebarrier(Base._unique_dims)(A, dims) + end +end diff --git a/src/utils.jl b/src/utils.jl index 988dc8de3..003035e42 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -183,7 +183,7 @@ should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_c # fixes #493 # TODO we probably want to skip rewrite if args do not contain Reactant types -should_rewrite_invoke(::Type{typeof(Base.unique)}, ::Type{Tuple{Vector{Symbol}}}) = false +# should_rewrite_invoke(::Type{typeof(Base.unique)}, ::Type{Tuple{Vector{Symbol}}}) = false # Avoid recursively interpreting into methods we define explicitly # as overloads, which we assume should handle the entirety of the From 5f6c31c13a9c094cd26243d528a1dc57ae101bc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 15 Jan 2025 18:32:01 +0100 Subject: [PATCH 7/7] remove previous solution --- src/utils.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 003035e42..61e522fe3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -181,10 +181,6 @@ end # by default, same as `should_rewrite_call` should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_call(ft) -# fixes #493 -# TODO we probably want to skip rewrite if args do not contain Reactant types -# should_rewrite_invoke(::Type{typeof(Base.unique)}, ::Type{Tuple{Vector{Symbol}}}) = false - # Avoid recursively interpreting into methods we define explicitly # as overloads, which we assume should handle the entirety of the # translation (and if not they can use call_in_reactant).