-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: adapt to sroa jll #521
Conversation
passes = [ | ||
"inline{default-pipeline=canonicalize max-iterations=4}" | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
passes = [ | |
"inline{default-pipeline=canonicalize max-iterations=4}" | |
] | |
passes = ["inline{default-pipeline=canonicalize max-iterations=4}"] |
push!(passes, "sroa-wrappers") | ||
push!(passes, "libdevice-funcs-raise") | ||
push!(passes, "canonicalize") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
push!(passes, "sroa-wrappers") | |
push!(passes, "libdevice-funcs-raise") | |
push!(passes, "canonicalize") | |
push!(passes, "sroa-wrappers") | |
push!(passes, "libdevice-funcs-raise") | |
push!(passes, "canonicalize") |
return join(passes, | ||
',', | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return join(passes, | |
',', | |
) | |
return join(passes, ',') |
@@ -379,12 +385,20 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||
if isdefined(Reactant_jll, :ptxas_path) | |||
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] | |||
end | |||
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" | |||
if DEBUG_KERNEL[] | |||
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult") | |
curesulthandler = XLA.Libdl.dlsym( | |
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" | |
) |
@@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError) | |||
) | |||
end | |||
|
|||
function make_tracer( | |||
seen, | |||
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}), | |
@nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}), |
@@ -2,6 +2,9 @@ using Reactant | |||
using Reactant: Ops | |||
|
|||
using Test | |||
|
|||
# Jax on Github CI dislikes X86 macos | |||
@static if !Sys.isapple() || Sys.ARCH != :x86_64 | |||
using PythonCall |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
using PythonCall | |
using PythonCall |
@@ -11,3 +14,4 @@ using PythonCall | |||
@test typeof(result) == ConcreteRNumber{Float32} | |||
@test result ≈ 6 | |||
end | |||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/integration/python.jl
Lines 10 to 11 in f289522
@testset "PythonCall" begin | |
jax = pyimport("jax") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/integration/python.jl
Lines 13 to 15 in f289522
result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3])) | |
@test typeof(result) == ConcreteRNumber{Float32} | |
@test result ≈ 6 |
@@ -425,6 +428,7 @@ function get_field_offset(T::Type, path) | |||
offset = 0 | |||
current_type = T | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
|
||
# Update current_type to the field's type for next iteration | ||
current_type = fieldtype(current_type, field_idx) | ||
current_type = tcurrent_type | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@@ -461,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||
blockdim = CUDA.CuDim3(blocks) | |||
threaddim = CUDA.CuDim3(threads) | |||
|
|||
if convert == Val(true) | |||
args = recudaconvert.(args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
args = recudaconvert.(args) | |
args = recudaconvert.(args) |
@@ -650,6 +662,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||
end | |||
|
|||
location = MLIR.IR.Location() | |||
@assert length(restys) == length(aliases) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@assert length(restys) == length(aliases) | |
@assert length(restys) == length(aliases) |
Depends on EnzymeAD/Enzyme-JAX#229
which itself depends on XLA adapting to
llvm/llvm-project#122650
llvm/llvm-project#122646
llvm/llvm-project#122615
llvm/llvm-project#122574
llvm/llvm-project#122572