Skip to content

Commit

Permalink
WIP: adapt to sroa jll
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 13, 2025
1 parent ca98c17 commit b47597f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 53 deletions.
2 changes: 1 addition & 1 deletion deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
prepareRegistry(registry);

mlir::registerenzymePasses();
regsiterenzymeXLAPasses();
registerenzymexlaPasses();

// Register the standard passes we want.
mlir::registerCSEPass();
Expand Down
74 changes: 37 additions & 37 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
ENZYMEXLA_COMMIT = "3b217bbfd5680ecd88c20285fe7b5693c541fa8b"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down Expand Up @@ -95,39 +95,39 @@ LLVM_TARGETS = select({
}) + ["AArch64", "X86", "ARM"]

# Uncomment these lines to use a custom LLVM commit
# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908"
# LLVM_SHA256 = ""
# http_archive(
# name = "llvm-raw",
# build_file_content = "# empty",
# sha256 = LLVM_SHA256,
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
# )
#
#
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
# maybe(
# http_archive,
# name = "llvm_zlib",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
# strip_prefix = "zlib-ng-2.0.7",
# urls = [
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
# ],
# )
#
# maybe(
# http_archive,
# name = "llvm_zstd",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
# strip_prefix = "zstd-1.5.2",
# urls = [
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
# ],
# )
LLVM_COMMIT = "9b4bf06be33f0fe6a4c487bb9244d8c0f6acab3f"
LLVM_SHA256 = ""
http_archive(
name = "llvm-raw",
build_file_content = "# empty",
sha256 = LLVM_SHA256,
strip_prefix = "llvm-project-" + LLVM_COMMIT,
urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
)


load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
maybe(
http_archive,
name = "llvm_zlib",
build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
strip_prefix = "zlib-ng-2.0.7",
urls = [
"https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
],
)

maybe(
http_archive,
name = "llvm_zstd",
build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
strip_prefix = "zstd-1.5.2",
urls = [
"https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
],
)

http_archive(
name = "jax",
Expand All @@ -138,9 +138,9 @@ http_archive(
patches = ["@enzyme_ad//:patches/jax.patch"],
)

# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
XLA_SHA256 = ""
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
# XLA_SHA256 = ""

http_archive(
name = "xla",
Expand Down
3 changes: 3 additions & 0 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
prev = Any[func.f, args...]
kernelargsym = gensym("kernelarg")
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
@show prev
@show Core.Typeof(prev)
@show seen
wrapper_tys = MLIR.IR.Type[]
for arg in values(seen)
if !(arg isa TracedRArray || arg isa TracedRNumber)
Expand Down
37 changes: 22 additions & 15 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function create_result(
end

# Optimization passes via transform dialect
function optimization_passes(; no_nan::Bool=false)
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
Expand Down Expand Up @@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
",",
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
return join(
[
"inline{default-pipeline=canonicalize max-iterations=4}",
"libdevice-funcs-raise",
func_passes,
],
passes = [
"inline{default-pipeline=canonicalize max-iterations=4}"
]
if sroa
push!(passes, "sroa-wrappers")
push!(passes, "libdevice-funcs-raise")
push!(passes, "canonicalize")
end
push!(passes, func_passes)
return join(passes,
',',
)
end
Expand All @@ -310,6 +314,8 @@ end
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"

function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
@show pass_pipeline
flush(stdout)
pm = MLIR.IR.PassManager()
MLIR.IR.enable_verifier!(pm, enable_verifier)
opm = MLIR.IR.OpPassManager(pm)
Expand Down Expand Up @@ -374,9 +380,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"

opt_passes = optimization_passes(; no_nan)
opt_passes2 = optimization_passes(; no_nan, sroa=false)

if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
Expand All @@ -387,14 +394,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
kern,
],
',',
),
)
elseif optimize === :before_kernel
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
Expand All @@ -405,13 +412,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
],
',',
),
)
elseif optimize === :no_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
Expand All @@ -420,7 +427,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
],
',',
),
Expand Down Expand Up @@ -449,14 +456,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
kern,
],
',',
),
)
elseif optimize === :before_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
Expand Down

0 comments on commit b47597f

Please sign in to comment.