From b47597ff65d1337b424fa799fd7e3b29950d60f7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 12 Jan 2025 19:27:54 -0500 Subject: [PATCH] WIP: adapt to sroa jll --- deps/ReactantExtra/API.cpp | 2 +- deps/ReactantExtra/WORKSPACE | 74 ++++++++++++++++++------------------ ext/ReactantCUDAExt.jl | 3 ++ src/Compiler.jl | 37 ++++++++++-------- 4 files changed, 63 insertions(+), 53 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f7ada88b2..1d0f085f8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -565,7 +565,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { prepareRegistry(registry); mlir::registerenzymePasses(); - regsiterenzymeXLAPasses(); + registerenzymexlaPasses(); // Register the standard passes we want. mlir::registerCSEPass(); diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6bdffaffd..631d8b8a6 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873" +ENZYMEXLA_COMMIT = "3b217bbfd5680ecd88c20285fe7b5693c541fa8b" ENZYMEXLA_SHA256 = "" http_archive( @@ -95,39 +95,39 @@ LLVM_TARGETS = select({ }) + ["AArch64", "X86", "ARM"] # Uncomment these lines to use a custom LLVM commit -# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908" -# LLVM_SHA256 = "" -# http_archive( -# name = "llvm-raw", -# build_file_content = "# empty", -# sha256 = LLVM_SHA256, -# strip_prefix = "llvm-project-" + LLVM_COMMIT, -# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], -# ) -# -# -# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -# maybe( -# http_archive, -# name = "llvm_zlib", -# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", -# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", -# strip_prefix = "zlib-ng-2.0.7", -# urls = [ -# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", -# ], -# ) -# -# maybe( -# http_archive, -# name = "llvm_zstd", -# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", -# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", -# strip_prefix = "zstd-1.5.2", -# urls = [ -# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" -# ], -# ) +LLVM_COMMIT = "9b4bf06be33f0fe6a4c487bb9244d8c0f6acab3f" +LLVM_SHA256 = "" +http_archive( + name = "llvm-raw", + build_file_content = "# empty", + sha256 = LLVM_SHA256, + strip_prefix = "llvm-project-" + LLVM_COMMIT, + urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], +) + + +load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +maybe( + http_archive, + name = "llvm_zlib", + build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", + sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", + strip_prefix = "zlib-ng-2.0.7", + urls = [ + "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", + ], +) + +maybe( + http_archive, + name = "llvm_zstd", + build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", + sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", + strip_prefix = "zstd-1.5.2", + urls = [ + "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" + ], +) http_archive( name = "jax", @@ -138,9 +138,9 @@ http_archive( patches = ["@enzyme_ad//:patches/jax.patch"], ) -# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5" -XLA_SHA256 = "" +load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") +# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5" +# XLA_SHA256 = "" http_archive( name = "xla", diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ca4d6efdf..ef13a6715 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -444,6 +444,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( prev = Any[func.f, args...] kernelargsym = gensym("kernelarg") Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack) + @show prev + @show Core.Typeof(prev) + @show seen wrapper_tys = MLIR.IR.Type[] for arg in values(seen) if !(arg isa TracedRArray || arg isa TracedRNumber) diff --git a/src/Compiler.jl b/src/Compiler.jl index 7bc4f29fa..52d1be8ab 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -116,7 +116,7 @@ function create_result( end # Optimization passes via transform dialect -function optimization_passes(; no_nan::Bool=false) +function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) transform_passes_list = [ "patterns=compare_op_canon<16>", "transpose_transpose<16>", @@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false) ",", ) func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") - return join( - [ - "inline{default-pipeline=canonicalize max-iterations=4}", - "libdevice-funcs-raise", - func_passes, - ], + passes = [ + "inline{default-pipeline=canonicalize max-iterations=4}" + ] + if sroa + push!(passes, "sroa-wrappers") + push!(passes, "libdevice-funcs-raise") + push!(passes, "canonicalize") + end + push!(passes, func_passes) + return join(passes, ',', ) end @@ -310,6 +314,8 @@ end const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) + @show pass_pipeline + flush(stdout) pm = MLIR.IR.PassManager() MLIR.IR.enable_verifier!(pm, enable_verifier) opm = MLIR.IR.OpPassManager(pm) @@ -374,9 +380,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" opt_passes = optimization_passes(; no_nan) + opt_passes2 = optimization_passes(; no_nan, sroa=false) if optimize === :all - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) @@ -387,14 +394,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, ], ',', ), ) elseif optimize === :before_kernel - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) @@ -405,13 +412,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), ) elseif optimize === :no_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, @@ -420,7 +427,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), @@ -449,14 +456,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, ], ',', ), ) elseif optimize === :before_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false )