From 6ddc890e367186107f23564aa4a1b86475a3f7d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Jan 2025 13:35:12 -0500 Subject: [PATCH 1/3] fix: generalize broadcast_in_dims for setindex (#518) * fix: generalize broadcast_in_dims for setindex * test: writing with less dims --- src/Ops.jl | 36 ++++++++++++++++++------------------ src/TracedRArray.jl | 17 +++++++++++++++-- src/TracedUtils.jl | 25 +------------------------ test/basic.jl | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 44 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index f67300787..0d94cb75d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -936,24 +936,24 @@ end end # broadcast ops -# function broadcast_in_dim( -# x::TracedRArray{T,N}, -# dims::Vector{Int}; -# location=mlir_stacktrace( -# "broadcast_in_dim", @__FILE__, @__LINE__ -# ), -# ) where {T,N} -# rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x)) -# res = MLIR.IR.result( -# stablehlo.broadcast_in_dim( -# x.mlir_data; -# result_0=restype, -# broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end +function broadcast_in_dim( + x::TracedRArray{T,N}, + dims::Vector{Int}, + result_size::Vector{Int}; + location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__), +) where {T,N} + @assert length(dims) == N + + res = MLIR.IR.result( + stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1), + location, + ), + ) + return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size)) +end @noinline function sort( x::TracedRArray{T,N}; diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a7d72b3ca..8b7835bcd 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -218,8 +218,21 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { return v end - v = TracedUtils.broadcast_to_size(v, length.(indices)) - v = TracedUtils.promote_to(TracedRArray{T,N}, v) + if v isa Number + v = TracedUtils.broadcast_to_size(v, length.(indices)) + v = TracedUtils.promote_to(TracedRArray{T,N}, v) + else + v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v) + non_integer_indices = [!(idx isa Integer) for idx in indices] + broadcast_dims = findall(non_integer_indices) + if length(broadcast_dims) == N + v = TracedUtils.broadcast_to_size(v, length.(indices)) + else + v = Ops.broadcast_in_dim( + materialize_traced_array(v), broadcast_dims, Int64.(length.(indices)) + ) + end + end indices = [ ( diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index ee9087557..ab7a55643 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -496,30 +496,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize) end @noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T} - dims = collect(Int64, 0:(length(size(x)) - 1)) - - if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims) - @show x - @show arg - @show rsize - @show rsize2 - @show dims - end - @assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims) - mlirty = MLIR.IR.type(get_mlir_data(x)) - - return TracedRArray{T,Int(length(rsize))}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.broadcast_in_dim( - get_mlir_data(x); - result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), - ), - 1, - ), - collect(rsize), - ) + return 
Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize)) end end diff --git a/test/basic.jl b/test/basic.jl index 531fec16e..83336f432 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -422,6 +422,39 @@ end # get_view_compiled = @compile get_view(x_concrete) end +function write_with_broadcast1!(x, y) + x[1, :, :] .= reshape(y, 4, 3) + return x +end +function write_with_broadcast2!(x, y) + x[:, 1, :] .= view(y, :, 1:3) + return x +end + +@testset "write_with_broadcast" begin + x_ra = Reactant.to_rarray(zeros(3, 4, 3)) + y_ra = Reactant.to_rarray(rand(3, 4)) + + res = @jit write_with_broadcast1!(x_ra, y_ra) + + @test res.data === x_ra.data + + res = Array(res) + y = Array(y_ra) + @test res[1, :, :] ≈ reshape(y, 4, 3) + + x_ra = Reactant.to_rarray(zeros(3, 4, 3)) + y_ra = Reactant.to_rarray(rand(3, 4)) + + res = @jit write_with_broadcast2!(x_ra, y_ra) + + @test res.data === x_ra.data + + res = Array(res) + y = Array(y_ra) + @test res[:, 1, :] ≈ view(y, :, 1:3) +end + function masking(x) y = similar(x) y[1:2, :] .= 0 From 7c2e390e8b105da83f512789d11f1155896bad4a Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 13 Jan 2025 03:25:58 +0100 Subject: [PATCH 2/3] linearize aliased kernel args (#504) * Add NoStopTracedTrack mode and use to handle aliased inputs * aliasing test * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: William S. Moses --- ext/ReactantCUDAExt.jl | 21 +++++++++++---------- src/Tracing.jl | 28 +++++++++++++++++++++++++--- test/integration/cuda.jl | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ca4d6efdf..d849efc61 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -443,7 +443,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( seen = Reactant.OrderedIdDict() prev = Any[func.f, args...] 
kernelargsym = gensym("kernelarg") - Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack) + Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.NoStopTracedTrack) wrapper_tys = MLIR.IR.Type[] for arg in values(seen) if !(arg isa TracedRArray || arg isa TracedRNumber) @@ -536,16 +536,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( if !(arg isa TracedRArray || arg isa TracedRNumber) continue end - for p in Reactant.TracedUtils.get_paths(arg) + + paths = Reactant.TracedUtils.get_paths(arg) + + arg = arg.mlir_data + arg = Reactant.TracedUtils.transpose_val(arg) + push!(restys, MLIR.IR.type(arg)) + push!(mlir_args, arg) + + for p in paths if p[1] !== kernelargsym continue end - - arg = arg.mlir_data - arg = Reactant.TracedUtils.transpose_val(arg) - push!(restys, MLIR.IR.type(arg)) - push!(mlir_args, arg) - # Get the allocation corresponding to which arg we're doing alloc = allocs[p[2]][1] @@ -580,9 +582,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( ), ), ) - - argidx += 1 end + argidx += 1 end MLIR.IR.block!(wrapbody) do diff --git a/src/Tracing.jl b/src/Tracing.jl index e00fdcb00..2f4885147 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -4,6 +4,7 @@ TracedToConcrete = 3 ArrayToConcrete = 4 TracedSetPath = 5 + NoStopTracedTrack = 6 end for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber) @@ -249,7 +250,7 @@ function traced_type( @inline base_typec(TV::TT) where {TT<:DataType} = (T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} return base_typec(T) - elseif mode == TracedTrack || mode == TracedSetPath + elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath return T else throw("Abstract RArray $T cannot be made concrete in mode $mode") @@ -261,7 +262,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:Trac throw("TracedRNG cannot be traced") elseif mode == TracedToConcrete return ConcreteRNG - elseif mode == TracedTrack || mode == TracedSetPath + elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath return T else throw("Unsupported mode: $mode") @@ -329,7 +330,7 @@ function make_tracer( track_numbers=(), kwargs..., ) where {RT} - if haskey(seen, prev) + if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end TT = traced_type(RT, (), Val(mode), track_numbers) @@ -460,6 +461,13 @@ function make_tracer( end return prev end + if mode == NoStopTracedTrack + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! + end + return prev + end if mode == TracedSetPath if haskey(seen, prev) return seen[prev] @@ -506,6 +514,13 @@ function make_tracer( end return prev end + if mode == NoStopTracedTrack + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! + end + return prev + end if mode == TracedSetPath if haskey(seen, prev) return seen[prev] @@ -546,6 +561,13 @@ function make_tracer( end return prev end + if mode == NoStopTracedTrack + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! 
+ end + return prev + end if mode == TracedSetPath haskey(seen, prev) && return seen[prev] res = MissingTracedValue((path,)) diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index ca445e3e2..817dfa740 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -115,4 +115,40 @@ tuplef2(a) = @cuda threads = 1 tuplef2!((5, a)) @code_hlo optimize = :before_kernel tuplef2(A) end end + A = ConcreteRArray(fill(1)) + if CUDA.functional() + @jit tuplef2(A) + @test all(Array(A) .≈ 5) + else + @code_hlo optimize = :before_kernel tuplef2(A) + end +end + +# TODO this same code fails if we use a 0-d array...? +# maybe weird cuda things +function aliased!(tup) + x, y = tup + x[2][1] *= y[2][1] + return nothing +end + +function aliased(s) + tup = (s, s) + @cuda threads = 1 aliased!(tup) + return nothing +end + +@static if !Sys.isapple() + @testset "Aliasing arguments" begin + a = ConcreteRArray([3]) + + s = (10, a) + + if CUDA.functional() + @jit aliased((s, s)) + @test all(Array(a) == 9) + else + @code_hlo optimize = :before_kernel aliased(s) + end + end end From e718f449ab061484b7ecaaa513020f6587e65d72 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 12 Jan 2025 21:09:28 -0600 Subject: [PATCH 3/3] Kernel: support constant input arg (#522) * Kernel: support constant input arg * Update utils.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantCUDAExt.jl | 44 ++++++++++++++++++++++++++++++++++------ src/Compiler.jl | 19 +++++++++++++---- src/utils.jl | 6 +++++- test/integration/cuda.jl | 25 +++++++++++++++++++++++ 4 files changed, 83 insertions(+), 11 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index d849efc61..9f787efba 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -9,6 +9,12 @@ using Adapt struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} ptr::Core.LLVMPtr{T,A} + + function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size} + push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs) + ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) + return new(ptr) + end end function Base.show(io::IO, a::AT) where {AT<:CuTracedArray} @@ -211,10 +217,34 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M return _derived_array(a, T, dims) end -function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} - res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}( - Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) +struct ReactantKernelAdaptor end + +function Adapt.adapt_storage(to::ReactantKernelAdaptor, p::CUDA.CuPtr) + return error("Cannot convert CuPtr argument of Reactant Kernel") +end +function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::DenseCuArray) + return Adapt.adapt_storage(ka, Array(xs)) +end +function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::Array) + return Adapt.adapt_storage(ka, Reactant.Ops.constant(xs)) +end +function Adapt.adapt_structure(to::ReactantKernelAdaptor, ref::Base.RefValue) + return error("Cannot convert RefValue argument of Reactant Kernel") +end +function Adapt.adapt_structure( + to::ReactantKernelAdaptor, bc::Broadcast.Broadcasted{Style,<:Any,Type{T}} +) where {Style,T} + return Broadcast.Broadcasted{Style}( + (x...) 
-> T(x...), Adapt.adapt(to, bc.args), bc.axes ) +end + +Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) + return adapt(ReactantKernelAdaptor(), arg) +end + +function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N} + res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs) return res end @@ -383,7 +413,8 @@ end function Reactant.make_tracer( seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs... ) - x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray + x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr)) + x = x::TracedRArray Reactant.make_tracer(seen, x, path, mode; kwargs...) return prev end @@ -441,9 +472,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( # linearize kernel arguments seen = Reactant.OrderedIdDict() - prev = Any[func.f, args...] kernelargsym = gensym("kernelarg") - Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.NoStopTracedTrack) + for (i, prev) in enumerate(Any[func.f, args...]) + Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack) + end wrapper_tys = MLIR.IR.Type[] for arg in values(seen) if !(arg isa TracedRArray || arg isa TracedRNumber) diff --git a/src/Compiler.jl b/src/Compiler.jl index 7bc4f29fa..d6565f761 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -318,26 +318,34 @@ function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) return mod end +const context_gc_vector = Dict{MLIR.IR.Context,Vector{TracedRArray}}() + # helper for debug purposes: String -> Text function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true) ctx = MLIR.IR.Context(Reactant.registry[], false) + context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - MLIR.IR.context!(ctx) do + result = MLIR.IR.context!(ctx) do mod = parse(MLIR.IR.Module, source) run_pass_pipeline!(mod, pass_pipeline; enable_verifier) MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true) Text(repr(mod)) end + Base.delete!(context_gc_vector, ctx) + return result end function compile_mlir(f, args; kwargs...) ctx = MLIR.IR.Context(Reactant.registry[], false) + context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - MLIR.IR.context!(ctx) do + results = MLIR.IR.context!(ctx) do mod = MLIR.IR.Module(MLIR.IR.Location()) evalinfo = compile_mlir!(mod, f, args; kwargs...) - return mod, evalinfo... + return (mod, evalinfo...) 
end + Base.delete!(context_gc_vector, ctx) + return results end const cuLaunch = Ref{UInt}(0) @@ -859,10 +867,11 @@ end function compile_xla(f, args; client=nothing, optimize=true, no_nan=false) # register MLIR dialects ctx = MLIR.IR.Context(Reactant.registry[], false) + context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid MLIR.IR.activate!(ctx) - return try + results = try # compile function to MLIR module mod = MLIR.IR.Module(MLIR.IR.Location()) linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!( @@ -888,6 +897,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false) finally MLIR.IR.deactivate!(ctx) end + Base.delete!(context_gc_vector, ctx) + return results end function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=false) diff --git a/src/utils.jl b/src/utils.jl index 2f79036cf..2a4cb185a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -195,7 +195,11 @@ struct MustThrowError end @generated function applyiterate_with_reactant( iteratefn, applyfn, args::Vararg{Any,N} ) where {N} - @assert iteratefn == typeof(Base.iterate) + if iteratefn != typeof(Base.iterate) + return quote + error("Unhandled apply_iterate with iteratefn=$iteratefn") + end + end newargs = Vector{Expr}(undef, N) for i in 1:N @inbounds newargs[i] = :(args[$i]...) diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index 817dfa740..da5d3c52b 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -152,3 +152,28 @@ end end end end + +using Reactant, CUDA + +function cmul!(a, b) + b[1] *= a[1] + return nothing +end + +function mixed(a, b) + @cuda threads = 1 cmul!(a, b) + return nothing +end + +@static if !Sys.isapple() + @testset "Non-traced argument" begin + if CUDA.functional() + a = CuArray([4]) + b = ConcreteRArray([3]) + + @jit mixed(a, b) + @test all(Array(a) == 4) + @test all(Array(b) == 12) + end + end +end
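
Editor's note, not part of the patch series: a minimal usage sketch of what the first patch enables, assuming only the Reactant.to_rarray / @jit API already exercised in the tests above. It mirrors the new write_with_broadcast1! test in test/basic.jl: the right-hand side of the in-place assignment has fewer dimensions than the indexed target, and setindex! now routes it through the generalized Ops.broadcast_in_dim wrapper instead of broadcast_to_size.

# Editor's illustrative sketch (hypothetical helper name); mirrors the new test/basic.jl test.
using Reactant

function fill_first_plane!(x, y)
    # y is 3x4; the indexed target x[1, :, :] spans dims of lengths (1, 4, 3).
    # The reshaped 4x3 value has lower rank than x, so setindex! broadcasts it
    # into place via Ops.broadcast_in_dim with broadcast_dims = [2, 3].
    x[1, :, :] .= reshape(y, 4, 3)
    return x
end

x = Reactant.to_rarray(zeros(3, 4, 3))
y = Reactant.to_rarray(rand(3, 4))
@jit fill_first_plane!(x, y)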