Skip to content

Commit

Permalink
Generalize precompilation support (#534)
Browse files Browse the repository at this point in the history
* Generalize precompilation support

* CUDA precompilation

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* fixup

* fix

* Update ReactantCUDAExt.jl

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Jan 15, 2025
1 parent 4b9434e commit b436b48
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
31 changes: 30 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ function compile(job)
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
)

GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
if !Reactant.precompiling()
GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
end
entryname = LLVM.name(meta.entry)

GPUCompiler.optimize_module!(job, mod)
Expand Down Expand Up @@ -788,4 +790,31 @@ function __init__()
return nothing
end

# Precompile workload for the CUDA extension: pushes a minimal `CUDA.@cuda` launch
# through `Reactant.Compiler.compile_mlir` while the package is being precompiled.
# The whole block is only defined when the host is neither Apple nor aarch64.
# NOTE(review): the reason for excluding those hosts is not visible here — confirm.
@static if !Sys.isapple() && Sys.ARCH != :aarch64
Reactant.PrecompileTools.@setup_workload begin
# Setup: initialize the dialect and create the CPU client used by the workload.
Reactant.initialize_dialect()
client = Reactant.XLA.CPUClient(; checkcount=false)
Reactant.PrecompileTools.@compile_workload begin
# Skip the workload body when Reactant reports precompilation as unsupported
# on the running Julia version.
@static if Reactant.precompilation_supported()
# Kernel: each thread squares, in place, the element at its own x thread index.
function square_kernel!(x)
i = CUDA.threadIdx().x
x[i] *= x[i]
return nothing
end

# Host-side wrapper: launches the kernel with one block and `length(x)` threads.
function square!(x)
CUDA.@cuda blocks = 1 threads = length(x) square_kernel!(x)
return nothing
end
# One-element input tied to the setup client, compiled with `optimize=false`.
y = Reactant.ConcreteRArray([2.0]; client)
Reactant.Compiler.compile_mlir(square!, (y,); optimize=false)
end
end
# Teardown, in order: free the client, null its handle, deinitialize the dialect,
# then purge the opaque closures created during the workload.
# NOTE(review): nulling `client.client` presumably keeps the freed pointer from
# being reused or serialized — confirm.
Reactant.XLA.free_client(client)
client.client = C_NULL
Reactant.deinitialize_dialect()
Reactant.clear_oc_cache()
end
end

end # module ReactantCUDAExt
40 changes: 25 additions & 15 deletions src/Precompile.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using PrecompileTools
using PrecompileTools: @setup_workload, @compile_workload

function infer_sig(sig)
Expand Down Expand Up @@ -34,15 +35,33 @@ function infer_sig(sig)
end
end

"""
    clear_oc_cache()

Purge every opaque closure recorded in `oc_capture_vec`.

Opaque closures capture the world age they were compiled in, so they are not
relocatable; everything created so far is therefore dropped explicitly.
Entries that are a `Base.RefValue` have their stored pointer overwritten with
`C_NULL`; any other entry is emptied with `empty!`.
"""
function clear_oc_cache()
    for entry in oc_capture_vec
        if !(entry isa Base.RefValue)
            empty!(entry)
            continue
        end
        # Treat the ref's object address as a pointer-sized slot and null it atomically.
        slot = Ptr{Ptr{Cvoid}}(pointer_from_objref(entry))
        Base.atomic_pointerset(slot, C_NULL, :monotonic)
    end
    return nothing
end

"""
    precompilation_supported() -> Bool

Return `true` when the running Julia version is at least 1.10.8 (which includes
every 1.11 and later release), and `false` otherwise.

Precompilation on 1.10 hits an apparent bug:
https://github.com/JuliaLang/julia/issues/56947
"""
precompilation_supported() = VERSION >= v"1.10.8"

"""
    precompiling() -> Bool

Return `true` exactly when the Julia runtime's `jl_generating_output` reports `1`,
and `false` for any other value.
"""
precompiling() = ccall(:jl_generating_output, Cint, ()) == 1

@setup_workload begin
initialize_dialect()
client = XLA.CPUClient(; checkcount=false)
@compile_workload begin
# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947
@static if VERSION < v"1.11"
else
# infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}})
# infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}})
@static if precompilation_supported()
x = ConcreteRNumber(2.0; client)
Reactant.compile(sin, (x,); client)

Expand All @@ -53,14 +72,5 @@ end
XLA.free_client(client)
client.client = C_NULL
deinitialize_dialect()
# Opaque closures capture the worldage of their compilation and thus are not relocatable
# Therefore we explicitly purge all OC's we have created here
for v in oc_capture_vec
if v isa Base.RefValue
p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v))
Base.atomic_pointerset(p, C_NULL, :monotonic)
else
empty!(v)
end
end
clear_oc_cache()
end

0 comments on commit b436b48

Please sign in to comment.