diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 9f787efba..2c6870f2c 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -328,7 +328,9 @@ function compile(job) # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false ) - GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job)) + if !Reactant.precompiling() + GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job)) + end entryname = LLVM.name(meta.entry) GPUCompiler.optimize_module!(job, mod) @@ -788,4 +790,31 @@ function __init__() return nothing end +@static if !Sys.isapple() && Sys.ARCH != :aarch64 + Reactant.PrecompileTools.@setup_workload begin + Reactant.initialize_dialect() + client = Reactant.XLA.CPUClient(; checkcount=false) + Reactant.PrecompileTools.@compile_workload begin + @static if Reactant.precompilation_supported() + function square_kernel!(x) + i = CUDA.threadIdx().x + x[i] *= x[i] + return nothing + end + + function square!(x) + CUDA.@cuda blocks = 1 threads = length(x) square_kernel!(x) + return nothing + end + y = Reactant.ConcreteRArray([2.0]; client) + Reactant.Compiler.compile_mlir(square!, (y,); optimize=false) + end + end + Reactant.XLA.free_client(client) + client.client = C_NULL + Reactant.deinitialize_dialect() + Reactant.clear_oc_cache() + end +end + end # module ReactantCUDAExt diff --git a/src/Precompile.jl b/src/Precompile.jl index 98c60dee5..a38d4720b 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -1,3 +1,4 @@ +using PrecompileTools using PrecompileTools: @setup_workload, @compile_workload function infer_sig(sig) @@ -34,15 +35,33 @@ function infer_sig(sig) end end +function clear_oc_cache() + # Opaque closures capture the worldage of their compilation and thus are not relocatable + # Therefore we explicitly purge all OC's we have created here + for v in oc_capture_vec + if v isa Base.RefValue + p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) + Base.atomic_pointerset(p, C_NULL, :monotonic) + else + empty!(v) + end + end +end + +# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 +function precompilation_supported() + return VERSION >= v"1.11" || VERSION >= v"1.10.8" +end + +function precompiling() + return (@ccall jl_generating_output()::Cint) == 1 +end + @setup_workload begin initialize_dialect() client = XLA.CPUClient(; checkcount=false) @compile_workload begin - # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 - @static if VERSION < v"1.11" - else - # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) - # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) + @static if precompilation_supported() x = ConcreteRNumber(2.0; client) Reactant.compile(sin, (x,); client) @@ -53,14 +72,5 @@ end XLA.free_client(client) client.client = C_NULL deinitialize_dialect() - # Opaque closures capture the worldage of their compilation and thus are not relocatable - # Therefore we explicitly purge all OC's we have created here - for v in oc_capture_vec - if v isa Base.RefValue - p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) - Base.atomic_pointerset(p, C_NULL, :monotonic) - else - empty!(v) - end - end + clear_oc_cache() end