Skip to content

Commit

Permalink
Merge branch 'main' into sroa
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 13, 2025
2 parents b47597f + e718f44 commit 2b7218c
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 70 deletions.
66 changes: 48 additions & 18 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ using Adapt

struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
ptr::Core.LLVMPtr{T,A}

function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs)
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
return new(ptr)
end
end

function Base.show(io::IO, a::AT) where {AT<:CuTracedArray}
Expand Down Expand Up @@ -211,10 +217,34 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
return _derived_array(a, T, dims)
end

function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(
Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
struct ReactantKernelAdaptor end

function Adapt.adapt_storage(to::ReactantKernelAdaptor, p::CUDA.CuPtr)
return error("Cannot convert CuPtr argument of Reactant Kernel")
end
function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::DenseCuArray)
return Adapt.adapt_storage(ka, Array(xs))
end
function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::Array)
return Adapt.adapt_storage(ka, Reactant.Ops.constant(xs))
end
function Adapt.adapt_structure(to::ReactantKernelAdaptor, ref::Base.RefValue)
return error("Cannot convert RefValue argument of Reactant Kernel")
end
function Adapt.adapt_structure(
to::ReactantKernelAdaptor, bc::Broadcast.Broadcasted{Style,<:Any,Type{T}}
) where {Style,T}
return Broadcast.Broadcasted{Style}(
(x...) -> T(x...), Adapt.adapt(to, bc.args), bc.axes
)
end

Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
return adapt(ReactantKernelAdaptor(), arg)
end

function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
return res
end

Expand Down Expand Up @@ -383,7 +413,8 @@ end
function Reactant.make_tracer(
seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
)
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))
x = x::TracedRArray
Reactant.make_tracer(seen, x, path, mode; kwargs...)
return prev
end
Expand Down Expand Up @@ -441,12 +472,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

# linearize kernel arguments
seen = Reactant.OrderedIdDict()
prev = Any[func.f, args...]
kernelargsym = gensym("kernelarg")
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
@show prev
@show Core.Typeof(prev)
@show seen
for (i, prev) in enumerate(Any[func.f, args...])
Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
end
wrapper_tys = MLIR.IR.Type[]
for arg in values(seen)
if !(arg isa TracedRArray || arg isa TracedRNumber)
Expand Down Expand Up @@ -539,16 +568,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
for p in Reactant.TracedUtils.get_paths(arg)

paths = Reactant.TracedUtils.get_paths(arg)

arg = arg.mlir_data
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

for p in paths
if p[1] !== kernelargsym
continue
end

arg = arg.mlir_data
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

# Get the allocation corresponding to which arg we're doing
alloc = allocs[p[2]][1]

Expand Down Expand Up @@ -583,9 +614,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
),
),
)

argidx += 1
end
argidx += 1
end

MLIR.IR.block!(wrapbody) do
Expand Down
19 changes: 15 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,26 +324,34 @@ function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
return mod
end

const context_gc_vector = Dict{MLIR.IR.Context,Vector{TracedRArray}}()

# helper for debug purposes: String -> Text
function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)
ctx = MLIR.IR.Context(Reactant.registry[], false)
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
MLIR.IR.context!(ctx) do
result = MLIR.IR.context!(ctx) do
mod = parse(MLIR.IR.Module, source)
run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
Text(repr(mod))
end
Base.delete!(context_gc_vector, ctx)
return result
end

function compile_mlir(f, args; kwargs...)
ctx = MLIR.IR.Context(Reactant.registry[], false)
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
MLIR.IR.context!(ctx) do
results = MLIR.IR.context!(ctx) do
mod = MLIR.IR.Module(MLIR.IR.Location())
evalinfo = compile_mlir!(mod, f, args; kwargs...)
return mod, evalinfo...
return (mod, evalinfo...)
end
Base.delete!(context_gc_vector, ctx)
return results
end

const cuLaunch = Ref{UInt}(0)
Expand Down Expand Up @@ -866,10 +874,11 @@ end
function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
# register MLIR dialects
ctx = MLIR.IR.Context(Reactant.registry[], false)
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

MLIR.IR.activate!(ctx)
return try
results = try
# compile function to MLIR module
mod = MLIR.IR.Module(MLIR.IR.Location())
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
Expand All @@ -895,6 +904,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
finally
MLIR.IR.deactivate!(ctx)
end
Base.delete!(context_gc_vector, ctx)
return results
end

function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=false)
Expand Down
36 changes: 18 additions & 18 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -936,24 +936,24 @@ end
end

# broadcast ops
# function broadcast_in_dim(
# x::TracedRArray{T,N},
# dims::Vector{Int};
# location=mlir_stacktrace(
# "broadcast_in_dim", @__FILE__, @__LINE__
# ),
# ) where {T,N}
# rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x))
# res = MLIR.IR.result(
# stablehlo.broadcast_in_dim(
# x.mlir_data;
# result_0=restype,
# broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
# location,
# ),
# )
# return TracedRArray{T,N}((), res, size(x))
# end
function broadcast_in_dim(
x::TracedRArray{T,N},
dims::Vector{Int},
result_size::Vector{Int};
location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__),
) where {T,N}
@assert length(dims) == N

res = MLIR.IR.result(
stablehlo.broadcast_in_dim(
x.mlir_data;
result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)),
broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1),
location,
),
)
return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size))
end

@noinline function sort(
x::TracedRArray{T,N};
Expand Down
17 changes: 15 additions & 2 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,21 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
return v
end

v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
if v isa Number
v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
else
v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v)
non_integer_indices = [!(idx isa Integer) for idx in indices]
broadcast_dims = findall(non_integer_indices)
if length(broadcast_dims) == N
v = TracedUtils.broadcast_to_size(v, length.(indices))
else
v = Ops.broadcast_in_dim(
materialize_traced_array(v), broadcast_dims, Int64.(length.(indices))
)
end
end

indices = [
(
Expand Down
25 changes: 1 addition & 24 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,30 +496,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize)
end

@noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T}
dims = collect(Int64, 0:(length(size(x)) - 1))

if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims)
@show x
@show arg
@show rsize
@show rsize2
@show dims
end
@assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims)
mlirty = MLIR.IR.type(get_mlir_data(x))

return TracedRArray{T,Int(length(rsize))}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.broadcast_in_dim(
get_mlir_data(x);
result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)),
broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
),
1,
),
collect(rsize),
)
return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize))
end

end
28 changes: 25 additions & 3 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TracedToConcrete = 3
ArrayToConcrete = 4
TracedSetPath = 5
NoStopTracedTrack = 6
end

for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
Expand Down Expand Up @@ -249,7 +250,7 @@ function traced_type(
@inline base_typec(TV::TT) where {TT<:DataType} =
(T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
return base_typec(T)
elseif mode == TracedTrack || mode == TracedSetPath
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
return T
else
throw("Abstract RArray $T cannot be made concrete in mode $mode")
Expand All @@ -261,7 +262,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:Trac
throw("TracedRNG cannot be traced")
elseif mode == TracedToConcrete
return ConcreteRNG
elseif mode == TracedTrack || mode == TracedSetPath
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
return T
else
throw("Unsupported mode: $mode")
Expand Down Expand Up @@ -329,7 +330,7 @@ function make_tracer(
track_numbers=(),
kwargs...,
) where {RT}
if haskey(seen, prev)
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
TT = traced_type(RT, (), Val(mode), track_numbers)
Expand Down Expand Up @@ -460,6 +461,13 @@ function make_tracer(
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
if haskey(seen, prev)
return seen[prev]
Expand Down Expand Up @@ -506,6 +514,13 @@ function make_tracer(
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
if haskey(seen, prev)
return seen[prev]
Expand Down Expand Up @@ -546,6 +561,13 @@ function make_tracer(
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
haskey(seen, prev) && return seen[prev]
res = MissingTracedValue((path,))
Expand Down
6 changes: 5 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ struct MustThrowError end
@generated function applyiterate_with_reactant(
iteratefn, applyfn, args::Vararg{Any,N}
) where {N}
@assert iteratefn == typeof(Base.iterate)
if iteratefn != typeof(Base.iterate)
return quote
error("Unhandled apply_iterate with iteratefn=$iteratefn")
end
end
newargs = Vector{Expr}(undef, N)
for i in 1:N
@inbounds newargs[i] = :(args[$i]...)
Expand Down
Loading

0 comments on commit 2b7218c

Please sign in to comment.