Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add grad (autodiff) with Enzyme #439

Draft
wants to merge 40 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
70f86f6
convert HLO to StableHLO
joelberkeley Dec 1, 2024
866d15a
working sometimes
joelberkeley Dec 1, 2024
793bfa1
working! thanks kevin
joelberkeley Dec 2, 2024
e07c7f8
wip
joelberkeley Dec 2, 2024
e1daab0
Merge branch 'master' into stablehlo-convert
joelberkeley Dec 2, 2024
c8bf80c
wip
joelberkeley Dec 2, 2024
b9a5f04
temp remove pack switch HEAD
joelberkeley Dec 2, 2024
fea8646
revert
joelberkeley Dec 3, 2024
364158b
wip
joelberkeley Dec 4, 2024
efab2a5
build mlir::ModuleOp
joelberkeley Dec 4, 2024
01c3cf2
working on linux
joelberkeley Dec 8, 2024
de00a64
wip
joelberkeley Dec 8, 2024
6cea6ba
wip
joelberkeley Dec 8, 2024
4f66569
tan
joelberkeley Dec 8, 2024
cf0d80f
wip
joelberkeley Dec 13, 2024
ed838f8
llvm
joelberkeley Dec 13, 2024
59a317d
tidy
joelberkeley Dec 13, 2024
b983e13
wip
joelberkeley Dec 13, 2024
189f40a
draft AD with enzyme
joelberkeley Dec 16, 2024
b2f7a65
Merge branch 'master' into stablehlo-ad
joelberkeley Dec 16, 2024
8abda25
shellcheck
joelberkeley Dec 16, 2024
edd3f39
shellcheck
joelberkeley Dec 16, 2024
c933e6a
update enzyme version
joelberkeley Dec 16, 2024
4e5a253
wip
joelberkeley Dec 19, 2024
8828373
first (almost e2e) draft
joelberkeley Dec 23, 2024
0a5f935
compiling with runtime errors
joelberkeley Dec 23, 2024
41bc6ad
wip
joelberkeley Dec 28, 2024
ff85828
moses suggestion
joelberkeley Jan 4, 2025
6bd295d
wip
joelberkeley Jan 4, 2025
768b5b5
everything
joelberkeley Jan 4, 2025
964c875
enz version
joelberkeley Jan 4, 2025
3460cb7
revert xla version
joelberkeley Jan 4, 2025
33843fc
wip
joelberkeley Jan 4, 2025
d50956a
wip
joelberkeley Jan 4, 2025
2b6d8dc
wip
joelberkeley Jan 4, 2025
857f3da
Merge branch 'master' into stablehlo-ad
joelberkeley Jan 20, 2025
69558cb
use loadDialect
joelberkeley Jan 20, 2025
8c319ee
wip
joelberkeley Jan 20, 2025
d4332df
start debugging properly
joelberkeley Jan 22, 2025
2e8f5d4
really solid progress
joelberkeley Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion XLA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2fb20601f1cc6cab7f29f8bc73d90cd31e74bba0
b44f55da3dac449f03466815ac431474f86fd73f
17 changes: 10 additions & 7 deletions dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@ short_revision () {
echo "${rev%%"${rev##??????????}"}"
}

install_xla () {
if [ -z "$2" ]; then
echo "Usage: install_xla <xla-revision> <install-path>."
exit 1;
fi

install_git_repository () {
if [ "$(ls -A "$2")" ]; then
echo "Directory at path $2 is not empty, refusing to install XLA to this directory."
exit 1;
Expand All @@ -22,8 +17,16 @@ install_xla () {
(
cd "$2"
git init
git remote add origin https://github.com/openxla/xla
git remote add origin "$3"
git fetch --depth 1 origin "$1"
git checkout FETCH_HEAD
)
}

install_xla () {
install_git_repository "$1" "$2" https://github.com/openxla/xla
}

install_enzyme () {
install_git_repository "$1" "$2" https://github.com/EnzymeAD/Enzyme-JAX.git
}
1 change: 1 addition & 0 deletions pjrt-plugins/xla-cpu/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
xla/
3 changes: 2 additions & 1 deletion pjrt-plugins/xla-cpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ case $osu in
;;
esac

xla_dir=$(mktemp -d)
xla_dir=pjrt-plugins/xla-cpu/xla
mkdir "$xla_dir"
install_xla "$rev" "$xla_dir"
(
cd "$xla_dir"
Expand Down
3 changes: 2 additions & 1 deletion pjrt-plugins/xla-cuda/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ case $osu in
;;
esac

xla_dir=$(mktemp -d)
xla_dir=pjrt-plugins/xla-cuda/xla
mkdir "$xla_dir"
install_xla "$rev" "$xla_dir"
(
cd "$xla_dir"
Expand Down
1 change: 1 addition & 0 deletions spidr/backend/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/Enzyme-JAX
/xla
22 changes: 22 additions & 0 deletions spidr/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,43 @@ cc_binary(
linkshared = True,
linkstatic = True,
srcs = [
"//src/Enzyme-JAX/src/enzyme_ad/jax",
"//src/Enzyme-JAX/src/enzyme_ad/jax/Passes",
"//src/Enzyme/enzyme/Enzyme/MLIR/Dialect",
"//src/Enzyme/enzyme/Enzyme/MLIR/Passes",
"//src/llvm/Support",
"//src/mlir/IR",
"//src/mlir/Pass",
"//src/stablehlo/dialect",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/hlo/translate",
"//src/xla/mlir_hlo/mhlo/IR",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src/xla/service",
"//src",
],
deps = [
"//src/Enzyme-JAX/src/enzyme_ad/jax",
"//src/Enzyme-JAX/src/enzyme_ad/jax/Passes",
"//src/Enzyme/enzyme/Enzyme/MLIR/Dialect",
"//src/Enzyme/enzyme/Enzyme/MLIR/Passes",
"//src/llvm/Support",
"//src/mlir/IR",
"//src/mlir/Pass",
"//src/stablehlo/dialect",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/hlo/translate",
"//src/xla/mlir_hlo/mhlo/IR",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src/xla/service",
"//src",
],
)
1 change: 1 addition & 0 deletions spidr/backend/ENZYME_JAX_VERSION
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
b6d6563aa3a3050474a4250bf18322f7ebf0b486
2 changes: 1 addition & 1 deletion spidr/backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.16
0.0.15
44 changes: 44 additions & 0 deletions spidr/backend/WORKSPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
### xla

# this must be a local repository not http archive
# so we can run ./configure.py before invoking bazel
local_repository(name = "xla", path = "xla")
Expand Down Expand Up @@ -28,3 +30,45 @@ xla_workspace0()

load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure")
cuda_configure(name = "local_config_cuda")

### Enzyme-JAX
# note enzyme-jax specifies XLA versions, which we're currently ignoring. Do we need to use their versions?
local_repository(name = "enzyme-jax", path = "Enzyme-JAX")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "hedron_compile_commands",

# Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here.
# Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README).
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",
# When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..."
)
# load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")
# hedron_compile_commands_setup()
# load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive")
# hedron_compile_commands_setup_transitive()
# load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive")
# hedron_compile_commands_setup_transitive_transitive()
# load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive")
# hedron_compile_commands_setup_transitive_transitive_transitive()

load("@enzyme-jax//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256")

http_archive(
name = "jax",
sha256 = JAX_SHA256,
strip_prefix = "jax-" + JAX_COMMIT,
urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)],
patch_args = ["-p1"],
patches = ["@enzyme-jax//:patches/jax.patch"],
)

http_archive(
name = "enzyme",
sha256 = ENZYME_SHA256,
strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)
10 changes: 8 additions & 2 deletions spidr/backend/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."
. ./dev.sh
rev="$(cat XLA_VERSION)"
xla_rev="$(cat XLA_VERSION)"
enzyme_rev="$(cat spidr/backend/ENZYME_JAX_VERSION)"

osu="$(uname)"
case $osu in
Expand All @@ -26,8 +27,13 @@ esac
(
cd spidr/backend
mkdir xla
install_xla "$rev" xla
install_xla "$xla_rev" xla
(cd xla; ./configure.py --backend=cpu --os=$os)
# depending on Enzyme-JAX is problematic as it fixes the XLA version. Can we only depend on enzyme?
# seems unlikely that they could decouple XLA entirely. They almost certainly can't decouple stablehlo
mkdir Enzyme-JAX
install_enzyme "$enzyme_rev" Enzyme-JAX
cat everything >> Enzyme-JAX/BUILD
bazel build //:c_xla
rm -rf xla
)
Expand Down
54 changes: 54 additions & 0 deletions spidr/backend/everything
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@

cc_library(
name = "everything",
srcs = [
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"//src/enzyme_ad/jax:RegistryUtils.cpp",
],
hdrs = [
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"//src/enzyme_ad/jax:RegistryUtils.h",
],
visibility = ["//visibility:public"],
deps = [
"@enzyme//:EnzymeMLIR",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:AsyncDialect",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVGPUDialect",
"@llvm-project//mlir:OpenMPDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"@stablehlo//:chlo_ops",
"@stablehlo//stablehlo/tests:check_ops",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:ComplexToLLVM",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:GPUToLLVMIRTranslation",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",

"@llvm-project//llvm:X86AsmParser",
"@llvm-project//llvm:X86CodeGen",
],
)
12 changes: 12 additions & 0 deletions spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
cc_library(
name = "jax",
linkstatic = True,
alwayslink = True,
srcs = glob(["*.cpp"]),
hdrs = glob(["*.h"]),
deps = [
"@enzyme-jax//:everything",
"//src/mlir/IR",
],
visibility = ["//visibility:public"],
)
33 changes: 33 additions & 0 deletions spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
cc_library(
name = "Passes",
linkstatic = True,
alwayslink = True,
srcs = glob(["*.cpp"]),
hdrs = glob(["*.h"]),
deps = [
"@xla//xla/hlo/builder:xla_builder",
"@xla//xla/hlo/translate:stablehlo",
"@xla//xla/hlo/builder/lib:math",
"@xla//xla/mlir_hlo:hlo_dialect_registration",
"@enzyme-jax//:everything",
"//src/mlir/IR",
"//src/mlir/Pass",
],
visibility = ["//visibility:public"],
)

cc_binary(
name = "example",
linkstatic = True,
srcs = glob(["*.cpp"]),
deps = [
"@xla//xla/hlo/builder:xla_builder",
"@xla//xla/hlo/translate:stablehlo",
"@xla//xla/hlo/builder/lib:math",
"@xla//xla/mlir_hlo:hlo_dialect_registration",
"@enzyme-jax//:everything",
"//src/mlir/IR",
"//src/mlir/Pass",
],
visibility = ["//visibility:public"],
)
Loading
Loading