Skip to content

Commit

Permalink
Move Julia bindings to MLIR dialects out of JLL (#166)
Browse files Browse the repository at this point in the history
* Rename generated dialect file names

* Add dialect binding file generator script

* Generate dialect files

* Automatize dialect regeneration

* Fix paths

* Fix docs generation
  • Loading branch information
mofeing authored Oct 9, 2024
1 parent babfa2f commit 2874e0d
Show file tree
Hide file tree
Showing 13 changed files with 9,796 additions and 41 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/regenerate-dialects.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Regenerate MLIR Dialects
on:
schedule:
- cron: '0 0 * * *'
workflow_dispatch:
jobs:
make:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- uses: julia-actions/setup-julia@v1
with:
version: '1.10'
- uses: actions/checkout@v4
with:
ref: main
- run: julia deps/ReactantExtra/make-dialects.jl
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Regenerate MLIR Dialects
title: 'Regenerate MLIR Dialects'
branch: regenerate-dialects
delete-branch: true
- name: Check outputs
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
16 changes: 8 additions & 8 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ gentbl_cc_library(
name = "BuiltinJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"Builtin.inc.jl"
"Builtin.jl"
)
],
td_file = "@llvm-project//mlir:include/mlir/IR/BuiltinOps.td",
Expand All @@ -420,7 +420,7 @@ gentbl_cc_library(
name = "ArithJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"Arith.inc.jl"
"Arith.jl"
)
],
td_file = "@llvm-project//mlir:include/mlir/Dialect/Arith/IR/ArithOps.td",
Expand All @@ -434,7 +434,7 @@ gentbl_cc_library(
name = "AffineJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"Affine.inc.jl"
"Affine.jl"
)
],
td_file = "@llvm-project//mlir:include/mlir/Dialect/Affine/IR/AffineOps.td",
Expand All @@ -448,7 +448,7 @@ gentbl_cc_library(
name = "FuncJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"Func.inc.jl"
"Func.jl"
)
],
td_file = "@llvm-project//mlir:include/mlir/Dialect/Func/IR/FuncOps.td",
Expand All @@ -462,7 +462,7 @@ gentbl_cc_library(
name = "EnzymeJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"Enzyme.inc.jl"
"Enzyme.jl"
)
],
td_file = "@enzyme//:Enzyme/MLIR/Dialect/EnzymeOps.td",
Expand All @@ -476,7 +476,7 @@ gentbl_cc_library(
name = "StableHLOJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"StableHLO.inc.jl"
"StableHLO.jl"
)
],
td_file = "@stablehlo//:stablehlo/dialect/StablehloOps.td",
Expand All @@ -490,7 +490,7 @@ gentbl_cc_library(
name = "CHLOJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"CHLO.inc.jl"
"CHLO.jl"
)
],
td_file = "@stablehlo//:stablehlo/dialect/ChloOps.td",
Expand All @@ -504,7 +504,7 @@ gentbl_cc_library(
name = "VHLOJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"VHLO.inc.jl"
"VHLO.jl"
)
],
td_file = "@stablehlo//:stablehlo/dialect/VhloOps.td",
Expand Down
18 changes: 18 additions & 0 deletions deps/ReactantExtra/make-dialects.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
for file in [
"Builtin.jl",
"Arith.jl",
"Affine.jl",
"Func.jl",
"Enzyme.jl",
"StableHLO.jl",
"CHLO.jl",
"VHLO.jl",
]
run(
`bazel build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`,
)
Base.Filesystem.cp(
joinpath(@__DIR__, "bazel-bin", file),
joinpath(dirname(dirname(@__DIR__)), "src", "mlir", "Dialects", file),
)
end
29 changes: 0 additions & 29 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,6 @@ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) # add Enzyme to environment stac

using Reactant
using Documenter
using Reactant_jll

struct TestRemote <: Remotes.Remote end
Remotes.repourl(::TestRemote) = "https://github.com/JuliaBinaryWrappers/Reactant_jll.jl"
function Remotes.fileurl(::TestRemote, ::Any, filename, linerange)
L1, L2 = first(linerange), last(linerange)
return "https://github.com/JuliaBinaryWrappers/Reactant_jll.jl/$(filename)#L$(L1)-$(L2)"
end
function Remotes.issueurl(::TestRemote, issue)
return "https://github.com/EnzymeAD/Reactant.jl/blob/$(issue)"
end

DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)

Expand All @@ -30,19 +19,6 @@ for (_, name) in examples
Literate.markdown(example_filepath, OUTPUT_DIR; documenter=true)
end

run(Cmd(`rm -rf .git`; dir=Reactant_jll.artifact_dir))
run(Cmd(`git init`; dir=Reactant_jll.artifact_dir))
run(Cmd(`git config user.name ReactantDocs`; dir=Reactant_jll.artifact_dir))
run(Cmd(`git config user.email [email protected]`; dir=Reactant_jll.artifact_dir))
run(
Cmd(
`git remote add origin https://github.com/EnzymeAD/Reactant.jl`;
dir=Reactant_jll.artifact_dir,
),
)
run(Cmd(`git add -A`; dir=Reactant_jll.artifact_dir))
run(Cmd(`git commit -m "Initial commit"`; dir=Reactant_jll.artifact_dir))

examples = [
title => joinpath("generated", string(name, ".md")) for (title, name) in examples
]
Expand All @@ -66,11 +42,6 @@ makedocs(;
Reactant.MLIR.Dialects.builtin,
],
authors="William Moses <[email protected]>, Valentin Churavy <[email protected]>",
remotes=Dict(
# Just non-repository directories
joinpath(@__DIR__, "..") => gh,
Reactant_jll.artifact_dir => TestRemote(),
),
sitename="Reactant.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
Expand Down
6 changes: 2 additions & 4 deletions src/mlir/Dialects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ function operandsegmentsizes(segments)
return namedattribute("operand_segment_sizes", Attribute(Int32.(segments)))
end

for path in readdir(Reactant_jll.artifact_dir; join=true)
if endswith("inc.jl")(path)
include(path)
end
for file in readdir(joinpath(@__DIR__, "Dialects"))
include(joinpath(@__DIR__, "Dialects", file))
end

end # module Dialects
Loading

1 comment on commit 2874e0d

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 2874e0d Previous: babfa2f Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1477123147 ns 1332320696 ns 1.11
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 214785620 ns 227642603 ns 0.94
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 6060496845 ns 5335100365 ns 1.14
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 36462862558 ns 13923447717 ns 2.62
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1187812104 ns 1316736699.5 ns 0.90
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8061529.5 ns 8429763.5 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1956949148 ns 1623537344.5 ns 1.21
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 4250420279 ns 2861216622 ns 1.49
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1262284206 ns 1322466226.5 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 21000200 ns 92359736 ns 0.23
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2531747932 ns 2131612752 ns 1.19
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6462833510 ns 5865040182.5 ns 1.10
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1257017538 ns 1307662641.5 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7263013.5 ns 7432819 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1799018534 ns 1467758920 ns 1.23
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 2829308508 ns 1560090503.5 ns 1.81
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1287683652 ns 1302764583 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 11229614 ns 11569359 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 2142481803 ns 1760726473 ns 1.22
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3163883256 ns 3822923609.5 ns 0.83
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1256704809 ns 1326375402 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 25518660.5 ns 91537054 ns 0.28
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2537928548 ns 2202666488.5 ns 1.15
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 6538820532.5 ns 3885854683 ns 1.68
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1243726167 ns 1296753848.5 ns 0.96
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 50161174 ns 115901535 ns 0.43
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3347331233 ns 3060865793 ns 1.09
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 9372958155 ns 5422475355 ns 1.73
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1328529825 ns 1357473731 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 68217886 ns 128651601.5 ns 0.53
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3568972676 ns 3799896887 ns 0.94
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 11764280358 ns 9419165654 ns 1.25
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1323662542 ns 1356160172 ns 0.98
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 19588351 ns 92550180 ns 0.21
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 2203161073 ns 2407987798 ns 0.91
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 5053716394 ns 2697762090 ns 1.87

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.