From 2874e0df65ad889bbab3c34a82d8d9b2cea77ebb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Wed, 9 Oct 2024 12:10:58 -0400 Subject: [PATCH] Move Julia bindings to MLIR dialects out of JLL (#166) * Rename generated dialect file names * Add dialect binding file generator script * Generate dialect files * Automatize dialect regeneration * Fix paths * Fix docs generation --- .github/workflows/regenerate-dialects.yml | 32 + deps/ReactantExtra/BUILD | 16 +- deps/ReactantExtra/make-dialects.jl | 18 + docs/make.jl | 29 - src/mlir/Dialects.jl | 6 +- src/mlir/Dialects/Affine.jl | 696 ++++ src/mlir/Dialects/Arith.jl | 1676 ++++++++++ src/mlir/Dialects/Builtin.jl | 93 + src/mlir/Dialects/CHLO.jl | 1158 +++++++ src/mlir/Dialects/Enzyme.jl | 191 ++ src/mlir/Dialects/Func.jl | 194 ++ src/mlir/Dialects/StableHLO.jl | 3720 +++++++++++++++++++++ src/mlir/Dialects/VHLO.jl | 2008 +++++++++++ 13 files changed, 9796 insertions(+), 41 deletions(-) create mode 100644 .github/workflows/regenerate-dialects.yml create mode 100644 deps/ReactantExtra/make-dialects.jl create mode 100755 src/mlir/Dialects/Affine.jl create mode 100755 src/mlir/Dialects/Arith.jl create mode 100755 src/mlir/Dialects/Builtin.jl create mode 100755 src/mlir/Dialects/CHLO.jl create mode 100755 src/mlir/Dialects/Enzyme.jl create mode 100755 src/mlir/Dialects/Func.jl create mode 100755 src/mlir/Dialects/StableHLO.jl create mode 100755 src/mlir/Dialects/VHLO.jl diff --git a/.github/workflows/regenerate-dialects.yml b/.github/workflows/regenerate-dialects.yml new file mode 100644 index 000000000..eca4dfb6d --- /dev/null +++ b/.github/workflows/regenerate-dialects.yml @@ -0,0 +1,32 @@ +name: Regenerate MLIR Dialects +on: + schedule: + - cron: '0 0 * * *' + workflow_dispatch: +jobs: + make: + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + steps: + - uses: julia-actions/setup-julia@v1 + with: + version: '1.10' + - uses: actions/checkout@v4 + with: + ref: main + - run: julia deps/ReactantExtra/make-dialects.jl + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v6 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Regenerate MLIR Dialects + title: 'Regenerate MLIR Dialects' + branch: regenerate-dialects + delete-branch: true + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index a60874e43..363233fda 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -406,7 +406,7 @@ gentbl_cc_library( name = "BuiltinJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "Builtin.inc.jl" + "Builtin.jl" ) ], td_file = "@llvm-project//mlir:include/mlir/IR/BuiltinOps.td", @@ -420,7 +420,7 @@ gentbl_cc_library( name = "ArithJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "Arith.inc.jl" + "Arith.jl" ) ], td_file = "@llvm-project//mlir:include/mlir/Dialect/Arith/IR/ArithOps.td", @@ -434,7 +434,7 @@ gentbl_cc_library( name = "AffineJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "Affine.inc.jl" + "Affine.jl" ) ], td_file = "@llvm-project//mlir:include/mlir/Dialect/Affine/IR/AffineOps.td", @@ -448,7 +448,7 @@ gentbl_cc_library( name = "FuncJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", 
"--disable-module-wrap=0"], - "Func.inc.jl" + "Func.jl" ) ], td_file = "@llvm-project//mlir:include/mlir/Dialect/Func/IR/FuncOps.td", @@ -462,7 +462,7 @@ gentbl_cc_library( name = "EnzymeJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "Enzyme.inc.jl" + "Enzyme.jl" ) ], td_file = "@enzyme//:Enzyme/MLIR/Dialect/EnzymeOps.td", @@ -476,7 +476,7 @@ gentbl_cc_library( name = "StableHLOJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "StableHLO.inc.jl" + "StableHLO.jl" ) ], td_file = "@stablehlo//:stablehlo/dialect/StablehloOps.td", @@ -490,7 +490,7 @@ gentbl_cc_library( name = "CHLOJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "CHLO.inc.jl" + "CHLO.jl" ) ], td_file = "@stablehlo//:stablehlo/dialect/ChloOps.td", @@ -504,7 +504,7 @@ gentbl_cc_library( name = "VHLOJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "VHLO.inc.jl" + "VHLO.jl" ) ], td_file = "@stablehlo//:stablehlo/dialect/VhloOps.td", diff --git a/deps/ReactantExtra/make-dialects.jl b/deps/ReactantExtra/make-dialects.jl new file mode 100644 index 000000000..4654721b3 --- /dev/null +++ b/deps/ReactantExtra/make-dialects.jl @@ -0,0 +1,18 @@ +for file in [ + "Builtin.jl", + "Arith.jl", + "Affine.jl", + "Func.jl", + "Enzyme.jl", + "StableHLO.jl", + "CHLO.jl", + "VHLO.jl", +] + run( + `bazel build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`, + ) + Base.Filesystem.cp( + joinpath(@__DIR__, "bazel-bin", file), + joinpath(dirname(dirname(@__DIR__)), "src", "mlir", "Dialects", file), + ) +end diff --git a/docs/make.jl b/docs/make.jl index 3063a9bcd..d57d1751e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,17 +2,6 @@ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) # add Enzyme to environment stac using Reactant using Documenter -using Reactant_jll - -struct TestRemote <: Remotes.Remote end -Remotes.repourl(::TestRemote) = "https://github.com/JuliaBinaryWrappers/Reactant_jll.jl" -function Remotes.fileurl(::TestRemote, ::Any, filename, linerange) - L1, L2 = first(linerange), last(linerange) - return "https://github.com/JuliaBinaryWrappers/Reactant_jll.jl/$(filename)#L$(L1)-$(L2)" -end -function Remotes.issueurl(::TestRemote, issue) - return "https://github.com/EnzymeAD/Reactant.jl/blob/$(issue)" -end DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true) @@ -30,19 +19,6 @@ for (_, name) in examples Literate.markdown(example_filepath, OUTPUT_DIR; documenter=true) end -run(Cmd(`rm -rf .git`; dir=Reactant_jll.artifact_dir)) -run(Cmd(`git init`; dir=Reactant_jll.artifact_dir)) -run(Cmd(`git config user.name ReactantDocs`; dir=Reactant_jll.artifact_dir)) -run(Cmd(`git config user.email ReactantDocs@wsmoses.com`; dir=Reactant_jll.artifact_dir)) -run( - Cmd( - `git remote add origin https://github.com/EnzymeAD/Reactant.jl`; - dir=Reactant_jll.artifact_dir, - ), -) -run(Cmd(`git add -A`; dir=Reactant_jll.artifact_dir)) -run(Cmd(`git commit -m "Initial commit"`; dir=Reactant_jll.artifact_dir)) - examples = [ title => joinpath("generated", string(name, ".md")) for (title, name) in examples ] @@ -66,11 +42,6 @@ makedocs(; Reactant.MLIR.Dialects.builtin, ], authors="William Moses , Valentin Churavy ", - remotes=Dict( - # Just non-repository directories - joinpath(@__DIR__, "..") => gh, - Reactant_jll.artifact_dir => TestRemote(), - ), sitename="Reactant.jl", format=Documenter.HTML(; prettyurls=get(ENV, 
"CI", "false") == "true", diff --git a/src/mlir/Dialects.jl b/src/mlir/Dialects.jl index 1cdd78347..50a90efff 100644 --- a/src/mlir/Dialects.jl +++ b/src/mlir/Dialects.jl @@ -16,10 +16,8 @@ function operandsegmentsizes(segments) return namedattribute("operand_segment_sizes", Attribute(Int32.(segments))) end -for path in readdir(Reactant_jll.artifact_dir; join=true) - if endswith("inc.jl")(path) - include(path) - end +for file in readdir(joinpath(@__DIR__, "Dialects")) + include(joinpath(@__DIR__, "Dialects", file)) end end # module Dialects diff --git a/src/mlir/Dialects/Affine.jl b/src/mlir/Dialects/Affine.jl new file mode 100755 index 000000000..ffc047192 --- /dev/null +++ b/src/mlir/Dialects/Affine.jl @@ -0,0 +1,696 @@ +module affine +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`apply` + +The `affine.apply` operation applies an [affine mapping](#affine-maps) +to a list of SSA values, yielding a single SSA value. The number of +dimension and symbol arguments to `affine.apply` must be equal to the +respective number of dimensional and symbolic inputs to the affine mapping; +the affine mapping has to be one-dimensional, and so the `affine.apply` +operation always returns one value. The input operands and result must all +have ‘index’ type. + +# Example + +```mlir +#map10 = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)> +... +%1 = affine.apply #map10 (%s, %t) + +// Inline example. +%2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n] +``` +""" +function apply(mapOperands::Vector{Value}; result_0=nothing::Union{Nothing, IR.Type}, map, location=Location()) + op_ty_results = IR.Type[] + operands = Value[mapOperands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "affine.apply", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`delinearize_index` + +The `affine.delinearize_index` operation takes a single index value and +calculates the multi-index according to the given basis. + +# Example + +``` +%indices:3 = affine.delinearize_index %linear_index into (%c16, %c224, %c224) : index, index, index +``` + +In the above example, `%indices:3` conceptually holds the following: + +``` +#map0 = affine_map<()[s0] -> (s0 floordiv 50176)> +#map1 = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)> +#map2 = affine_map<()[s0] -> (s0 mod 224)> +%indices_0 = affine.apply #map0()[%linear_index] +%indices_1 = affine.apply #map1()[%linear_index] +%indices_2 = affine.apply #map2()[%linear_index] +``` +""" +function delinearize_index(linear_index::Value, basis::Vector{Value}; multi_index=nothing::Union{Nothing, Vector{IR.Type}}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[linear_index, basis..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(multi_index) && push!(op_ty_results, multi_index...) + + create_operation( + "affine.delinearize_index", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`for_` + +# Syntax + +``` +operation ::= `affine.for` ssa-id `=` lower-bound `to` upper-bound + (`step` integer-literal)? `{` op* `}` + +lower-bound ::= `max`? affine-map-attribute dim-and-symbol-use-list | shorthand-bound +upper-bound ::= `min`? affine-map-attribute dim-and-symbol-use-list | shorthand-bound +shorthand-bound ::= ssa-id | `-`? integer-literal +``` + +The `affine.for` operation represents an affine loop nest. It has one region +containing its body. This region must contain one block that terminates with +[`affine.yield`](#affineyield-mliraffineyieldop). *Note:* when +`affine.for` is printed in custom format, the terminator is omitted. The +block has one argument of [`index`](Builtin.md/#indextype) type that +represents the induction variable of the loop. + +The `affine.for` operation executes its body a number of times iterating +from a lower bound to an upper bound by a stride. The stride, represented by +`step`, is a positive constant integer which defaults to \"1\" if not present. +The lower and upper bounds specify a half-open range: the range includes the +lower bound but does not include the upper bound. + +The lower and upper bounds of an `affine.for` operation are represented as an +application of an affine mapping to a list of SSA values passed to the map. +The [same restrictions](#restrictions-on-dimensions-and-symbols) hold for +these SSA values as for all bindings of SSA values to dimensions and +symbols. + +The affine mappings for the bounds may return multiple results, in which +case the `max`/`min` keywords are required (for the lower/upper bound +respectively), and the bound is the maximum/minimum of the returned values. +There is no semantic ambiguity, but MLIR syntax requires the use of these +keywords to make things more obvious to human readers. + +Many upper and lower bounds are simple, so MLIR accepts two custom form +syntaxes: the form that accepts a single \'ssa-id\' (e.g. `%N`) is shorthand +for applying that SSA value to a function that maps a single symbol to +itself, e.g., `()[s]->(s)()[%N]`. The integer literal form (e.g. `-42`) is +shorthand for a nullary mapping function that returns the constant value +(e.g. `()->(-42)()`). + +Example showing reverse iteration of the inner loop: + +```mlir +#map57 = affine_map<(d0)[s0] -> (s0 - d0 - 1)> + +func.func @simple_example(%A: memref<?xf32>, %B: memref<?xf32>) { + %N = dim %A, 0 : memref<?xf32> + affine.for %i = 0 to %N step 1 { + affine.for %j = 0 to %N { // implicitly steps by 1 + %0 = affine.apply #map57(%j)[%N] + %tmp = call @F1(%A, %i, %0) : (memref<?xf32>, index, index)->(f32) + call @F2(%tmp, %B, %i, %0) : (f32, memref<?xf32>, index, index)->() + } + } + return +} +``` +`affine.for` can also operate on loop-carried variables (`iter_args`) and +return the final values after loop termination. The initial values of the +variables are passed as additional SSA operands to the `affine.for` +following the operands for the loop\'s lower and upper bounds. The +operation\'s region has equivalent arguments for each variable representing +the value of the variable at the current iteration. + +The region must terminate with an `affine.yield` that passes all the current +iteration variables to the next iteration, or to the `affine.for`\'s results +if at the last iteration. For `affine.for`\'s that execute zero iterations, the +initial values of the loop-carried variables (corresponding to the SSA +operands) will be the op\'s results.
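+ +From the Julia side, the generated `for_` builder below takes these pieces +explicitly. A minimal, hypothetical sketch (the names `lb_map`, `ub_map`, +`step_attr`, `init`, `f32_t`, and `body` are illustrative placeholders for +attributes, a `Value`, an `IR.Type`, and a `Region` assumed to be built +elsewhere): + +```julia +# Hedged sketch: a loop with attribute-only bounds carrying one value; +# `body` is a Region whose single block ends in `affine.yield`. +loop = for_(Value[], Value[], [init]; + results=[f32_t], lowerBoundMap=lb_map, upperBoundMap=ub_map, + step=step_attr, region=body) +```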
+ +For example, to sum-reduce a memref: + + ```mlir +func.func @reduce(%buffer: memref<1024xf32>) -> (f32) { + // Initial sum set to 0. + %sum_0 = arith.constant 0.0 : f32 + // iter_args binds initial values to the loop\'s region arguments. + %sum = affine.for %i = 0 to 10 step 2 + iter_args(%sum_iter = %sum_0) -> (f32) { + %t = affine.load %buffer[%i] : memref<1024xf32> + %sum_next = arith.addf %sum_iter, %t : f32 + // Yield current iteration sum to next iteration %sum_iter or to %sum + // if final iteration. + affine.yield %sum_next : f32 + } + return %sum : f32 +} +``` + +```mlir +%res:2 = affine.for %i = 0 to 128 iter_args(%arg0 = %init0, %arg1 = %init1) + -> (index, index) { + %y0 = arith.addi %arg0, %c1 : index + %y1 = arith.addi %arg1, %c2 : index + affine.yield %y0, %y1 : index, index +} +``` +If the `affine.for` defines any values, a yield terminator must be +explicitly present. The number and types of the \"affine.for\" results must +match the initial values in the `iter_args` binding and the yield operands. +""" +function for_(lowerBoundOperands::Vector{Value}, upperBoundOperands::Vector{Value}, inits::Vector{Value}; results::Vector{IR.Type}, lowerBoundMap, upperBoundMap, step, region::Region, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[lowerBoundOperands..., upperBoundOperands..., inits..., ] + owned_regions = Region[region, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("lowerBoundMap", lowerBoundMap), namedattribute("upperBoundMap", upperBoundMap), namedattribute("step", step), ] + push!(attributes, operandsegmentsizes([length(lowerBoundOperands), length(upperBoundOperands), length(inits), ])) + + create_operation( + "affine.for", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`if_` + +# Syntax + +``` +operation ::= `affine.if` if-op-cond `{` op* `}` (`else` `{` op* `}`)? +if-op-cond ::= integer-set-attr dim-and-symbol-use-list +``` + +The `affine.if` operation restricts execution to a subset of the loop +iteration space defined by an integer set (a conjunction of affine +constraints). A single `affine.if` may end with an optional `else` clause. + +The condition of the `affine.if` is represented by an +[integer set](#integer-sets) (a conjunction of affine constraints), +and the SSA values bound to the dimensions and symbols in the integer set. +The [same restrictions](#restrictions-on-dimensions-and-symbols) hold for +these SSA values as for all bindings of SSA values to dimensions and +symbols. + +The `affine.if` operation contains two regions for the \"then\" and \"else\" +clauses. `affine.if` may return results that are defined in its regions. +The values defined are determined by which execution path is taken. Each +region of the `affine.if` must contain a single block with no arguments, +and be terminated by `affine.yield`. If `affine.if` defines no values, +the `affine.yield` can be left out, and will be inserted implicitly. +Otherwise, it must be explicit. If no values are defined, the else block +may be empty (i.e. contain no blocks). 
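+ +A hypothetical sketch of the corresponding Julia builder `if_` defined below +(`cond_operands`, `then_r`, `else_r`, and `f32_t` are illustrative names for +values assumed to be built beforehand; both regions must end in +`affine.yield`): + +```julia +# Hedged sketch: an affine.if that yields a single f32 value. +val = if_(cond_operands; results=[f32_t], thenRegion=then_r, elseRegion=else_r) +```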
+ +# Example + +```mlir +#set = affine_set<(d0, d1)[s0]: (d0 - 10 >= 0, s0 - d0 - 9 >= 0, + d1 - 10 >= 0, s0 - d1 - 9 >= 0)> +func.func @reduced_domain_example(%A, %X, %N) : (memref<10xi32>, i32, i32) { + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { + %0 = affine.apply #map42(%j) + %tmp = call @S1(%X, %i, %0) + affine.if #set(%i, %j)[%N] { + %1 = affine.apply #map43(%i, %j) + call @S2(%tmp, %A, %i, %1) + } + } + } + return +} +``` + +Example with an explicit yield (initialization with edge padding): + +```mlir +#interior = affine_set<(i, j) : (i - 1 >= 0, j - 1 >= 0, 10 - i >= 0, 10 - j >= 0)> (%i, %j) +func.func @pad_edges(%I : memref<10x10xf32>) -> (memref<12x12xf32) { + %O = alloc memref<12x12xf32> + affine.parallel (%i, %j) = (0, 0) to (12, 12) { + %1 = affine.if #interior (%i, %j) { + %2 = load %I[%i - 1, %j - 1] : memref<10x10xf32> + affine.yield %2 + } else { + %2 = arith.constant 0.0 : f32 + affine.yield %2 : f32 + } + affine.store %1, %O[%i, %j] : memref<12x12xf32> + } + return %O +} +``` +""" +function if_(operand_0::Vector{Value}; results::Vector{IR.Type}, thenRegion::Region, elseRegion::Region, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operand_0..., ] + owned_regions = Region[thenRegion, elseRegion, ] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "affine.if", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`load` + +# Syntax + +``` +operation ::= ssa-id `=` `affine.load` ssa-use `[` multi-dim-affine-map-of-ssa-ids `]` `:` memref-type +``` + +The `affine.load` op reads an element from a memref, where the index +for each memref dimension is an affine expression of loop induction +variables and symbols. The output of `affine.load` is a new value with the +same type as the elements of the memref. An affine expression of loop IVs +and symbols must be specified for each dimension of the memref. The keyword +`symbol` can be used to indicate SSA identifiers which are symbolic. + +Example 1: + +```mlir +%1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +``` + +Example 2: Uses `symbol` keyword for symbols `%n` and `%m`. + +```mlir +%1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32> +``` +""" +function load(memref::Value, indices::Vector{Value}; result::IR.Type, map, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[memref, indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + + create_operation( + "affine.load", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`max` + +The `affine.max` operation computes the maximum value result from a multi-result +affine map. + +# Example + +```mlir +%0 = affine.max (d0) -> (1000, d0 + 512) (%i0) : index +``` +""" +function max(operands::Vector{Value}; result_0=nothing::Union{Nothing, IR.Type}, map, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "affine.max", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`min` + +# Syntax + +``` +operation ::= ssa-id `=` `affine.min` affine-map-attribute dim-and-symbol-use-list +``` + +The `affine.min` operation applies an [affine mapping](#affine-expressions) +to a list of SSA values, and returns the minimum value of all result +expressions. The number of dimension and symbol arguments to `affine.min` +must be equal to the respective number of dimensional and symbolic inputs to +the affine mapping; the `affine.min` operation always returns one value. The +input operands and result must all have \'index\' type. + +# Example + +```mlir +%0 = affine.min affine_map<(d0)[s0] -> (1000, d0 + 512, s0)> (%arg0)[%arg1] +``` +""" +function min(operands::Vector{Value}; result_0=nothing::Union{Nothing, IR.Type}, map, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "affine.min", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`parallel` + +The `affine.parallel` operation represents a hyper-rectangular affine +parallel band, defining zero or more SSA values for its induction variables. +It has one region capturing the parallel band body. The induction variables +are represented as arguments of this region. These SSA values always have +type index, which is the size of the machine word. The strides, represented +by steps, are positive constant integers which defaults to \"1\" if not +present. The lower and upper bounds specify a half-open range: the range +includes the lower bound but does not include the upper bound. The body +region must contain exactly one block that terminates with `affine.yield`. + +The lower and upper bounds of a parallel operation are represented as an +application of an affine mapping to a list of SSA values passed to the map. +The same restrictions hold for these SSA values as for all bindings of SSA +values to dimensions and symbols. The list of expressions in each map is +interpreted according to the respective bounds group attribute. If a single +expression belongs to the group, then the result of this expression is taken +as a lower(upper) bound of the corresponding loop induction variable. If +multiple expressions belong to the group, then the lower(upper) bound is the +max(min) of these values obtained from these expressions. The loop band has +as many loops as elements in the group bounds attributes. + +Each value yielded by `affine.yield` will be accumulated/reduced via one of +the reduction methods defined in the AtomicRMWKind enum. The order of +reduction is unspecified, and lowering may produce any valid ordering. +Loops with a 0 trip count will produce as a result the identity value +associated with each reduction (i.e. 0.0 for addf, 1.0 for mulf). Assign +reductions for loops with a trip count != 1 produces undefined results. + +Note: Calling `AffineParallelOp::build` will create the required region and +block, and insert the required terminator if it is trivial (i.e. no values +are yielded). Parsing will also create the required region, block, and +terminator, even when they are missing from the textual representation. 
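+ +From Julia, the generated `parallel` builder below simply forwards these +pieces; a minimal, hypothetical sketch (every name shown is an illustrative +placeholder for an attribute, region, or value constructed elsewhere): + +```julia +# Hedged sketch: all bounds, groups, steps, and reductions attributes are +# assumed to be pre-built; `body` is a Region ending in `affine.yield`. +band = parallel(Value[]; results=IR.Type[], reductions=red_attr, + lowerBoundsMap=lb_map, lowerBoundsGroups=lb_groups, + upperBoundsMap=ub_map, upperBoundsGroups=ub_groups, + steps=steps_attr, region=body) +```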
+ +Example (3x3 valid convolution): + +```mlir +func.func @conv_2d(%D : memref<100x100xf32>, %K : memref<3x3xf32>) -> (memref<98x98xf32>) { + %O = memref.alloc() : memref<98x98xf32> + affine.parallel (%x, %y) = (0, 0) to (98, 98) { + %0 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce (\"addf\") -> f32 { + %1 = affine.load %D[%x + %kx, %y + %ky] : memref<100x100xf32> + %2 = affine.load %K[%kx, %ky] : memref<3x3xf32> + %3 = arith.mulf %1, %2 : f32 + affine.yield %3 : f32 + } + affine.store %0, %O[%x, %y] : memref<98x98xf32> + } + return %O : memref<98x98xf32> +} +``` + +Example (tiling by potentially imperfectly dividing sizes): + +```mlir +affine.parallel (%ii, %jj) = (0, 0) to (%N, %M) step (32, 32) { + affine.parallel (%i, %j) = (%ii, %jj) + to (min(%ii + 32, %N), min(%jj + 32, %M)) { + call @f(%i, %j) : (index, index) -> () + } +} +``` +""" +function parallel(mapOperands::Vector{Value}; results::Vector{IR.Type}, reductions, lowerBoundsMap, lowerBoundsGroups, upperBoundsMap, upperBoundsGroups, steps, region::Region, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[mapOperands..., ] + owned_regions = Region[region, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("reductions", reductions), namedattribute("lowerBoundsMap", lowerBoundsMap), namedattribute("lowerBoundsGroups", lowerBoundsGroups), namedattribute("upperBoundsMap", upperBoundsMap), namedattribute("upperBoundsGroups", upperBoundsGroups), namedattribute("steps", steps), ] + + create_operation( + "affine.parallel", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`prefetch` + +The `affine.prefetch` op prefetches data from a memref location described +with an affine subscript similar to affine.load, and has three attributes: +a read/write specifier, a locality hint, and a cache type specifier as shown +below: + +```mlir +affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> +``` + +The read/write specifier is either \'read\' or \'write\', the locality hint +specifier ranges from locality<0> (no locality) to locality<3> (extremely +local keep in cache). The cache type specifier is either \'data\' or \'instr\' +and specifies whether the prefetch is performed on data cache or on +instruction cache. +""" +function prefetch(memref::Value, indices::Vector{Value}; isWrite, localityHint, isDataCache, map, location=Location()) + op_ty_results = IR.Type[] + operands = Value[memref, indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("isWrite", isWrite), namedattribute("localityHint", localityHint), namedattribute("isDataCache", isDataCache), namedattribute("map", map), ] + + create_operation( + "affine.prefetch", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`store` + +# Syntax + +``` +operation ::= `affine.store` ssa-use, ssa-use `[` multi-dim-affine-map-of-ssa-ids `]` `:` memref-type +``` + +The `affine.store` op writes an element to a memref, where the index +for each memref dimension is an affine expression of loop induction +variables and symbols. The `affine.store` op stores a new value which is the +same type as the elements of the memref. An affine expression of loop IVs +and symbols must be specified for each dimension of the memref. The keyword +`symbol` can be used to indicate SSA identifiers which are symbolic. 
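+ +A hypothetical Julia-side sketch of the `store` builder defined below (`v`, +`buf`, `i`, and `j` are existing `Value`s and `idx_map` a pre-built affine +map attribute; all names are illustrative): + +```julia +# Hedged sketch: store `v` at buf[i, j] through an affine map. +store(v, buf, [i, j]; map=idx_map) +```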
+ +Example 1: + +```mlir +affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +``` + +Example 2: Uses `symbol` keyword for symbols `%n` and `%m`. + +```mlir +affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32> +``` +""" +function store(value::Value, memref::Value, indices::Vector{Value}; map, location=Location()) + op_ty_results = IR.Type[] + operands = Value[value, memref, indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + + create_operation( + "affine.store", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`vector_load` + +The `affine.vector_load` is the vector counterpart of +[affine.load](#affineload-mliraffineloadop). It reads a slice from a +[MemRef](Builtin.md/#memreftype), supplied as its first operand, +into a [vector](Builtin.md/#vectortype) of the same base elemental type. +The index for each memref dimension is an affine expression of loop induction +variables and symbols. These indices determine the start position of the read +within the memref. The shape of the return vector type determines the shape of +the slice read from the memref. This slice is contiguous along the respective +dimensions of the shape. Strided vector loads will be supported in the future. +An affine expression of loop IVs and symbols must be specified for each +dimension of the memref. The keyword `symbol` can be used to indicate SSA +identifiers which are symbolic. + +Example 1: 8-wide f32 vector load. + +```mlir +%1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32> +``` + +Example 2: 4-wide f32 vector load. Uses `symbol` keyword for symbols `%n` and `%m`. + +```mlir +%1 = affine.vector_load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32> +``` + +Example 3: 2-dim f32 vector load. + +```mlir +%1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32> +``` + +TODOs: +* Add support for strided vector loads. +* Consider adding a permutation map to permute the slice that is read from memory +(see [vector.transfer_read](../Vector/#vectortransfer_read-mlirvectortransferreadop)). +""" +function vector_load(memref::Value, indices::Vector{Value}; result::IR.Type, map, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[memref, indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + + create_operation( + "affine.vector_load", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`vector_store` + +The `affine.vector_store` is the vector counterpart of +[affine.store](#affinestore-mliraffinestoreop). It writes a +[vector](Builtin.md/#vectortype), supplied as its first operand, +into a slice within a [MemRef](Builtin.md/#memreftype) of the same base +elemental type, supplied as its second operand. +The index for each memref dimension is an affine expression of loop +induction variables and symbols. These indices determine the start position +of the write within the memref. The shape of the input vector determines the +shape of the slice written to the memref. This slice is contiguous along the +respective dimensions of the shape. Strided vector stores will be supported +in the future. +An affine expression of loop IVs and symbols must be specified for each +dimension of the memref.
The keyword `symbol` can be used to indicate SSA +identifiers which are symbolic. + +Example 1: 8-wide f32 vector store. + +```mlir +affine.vector_store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32> +``` + +Example 2: 4-wide f32 vector store. Uses `symbol` keyword for symbols `%n` and `%m`. + +```mlir +affine.vector_store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32> +``` + +Example 3: 2-dim f32 vector store. + +```mlir +affine.vector_store %v0, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32> +``` + +TODOs: +* Add support for strided vector stores. +* Consider adding a permutation map to permute the slice that is written to memory +(see [vector.transfer_write](../Vector/#vectortransfer_write-mlirvectortransferwriteop)). +""" +function vector_store(value::Value, memref::Value, indices::Vector{Value}; map, location=Location()) + op_ty_results = IR.Type[] + operands = Value[value, memref, indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("map", map), ] + + create_operation( + "affine.vector_store", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`yield` + +The `affine.yield` yields zero or more SSA values from an affine op region and +terminates the region. The semantics of how the values yielded are used +is defined by the parent operation. +If `affine.yield` has any operands, the operands must match the parent +operation\'s results. +If the parent operation defines no values, then the `affine.yield` may be +left out in the custom syntax and the builders will insert one implicitly. +Otherwise, it has to be present in the syntax to indicate which values are +yielded. +""" +function yield(operands::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "affine.yield", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +end # affine diff --git a/src/mlir/Dialects/Arith.jl b/src/mlir/Dialects/Arith.jl new file mode 100755 index 000000000..bdf9fc2f3 --- /dev/null +++ b/src/mlir/Dialects/Arith.jl @@ -0,0 +1,1676 @@ +module arith +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`addf` + +The `addf` operation takes two operands and returns one result, each of +these is required to be the same type. This type may be a floating point +scalar type, a vector whose element type is a floating point type, or a +floating point tensor. + +# Example + +```mlir +// Scalar addition. +%a = arith.addf %b, %c : f64 + +// SIMD vector addition, e.g. for Intel SSE. +%f = arith.addf %g, %h : vector<4xf32> + +// Tensor addition. +%x = arith.addf %y, %z : tensor<4x?xbf16> +``` + +TODO: In the distant future, this will accept optional attributes for fast +math, contraction, rounding mode, and other controls. 
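+ +A hypothetical sketch of the generated Julia builder that follows (`lhs` and +`rhs` are existing float `Value`s and `f64_t` a pre-built `IR.Type`; all +names are illustrative): + +```julia +# Hedged sketch: the result type may be passed explicitly as below or +# omitted to let the op infer it. +s = addf(lhs, rhs; result=f64_t) +```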
+""" +function addf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.addf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`addi` + +Performs N-bit addition on the operands. The operands are interpreted as +unsigned bitvectors. The result is represented by a bitvector containing the +mathematical value of the addition modulo 2^n, where `n` is the bitwidth. +Because `arith` integers use a two\'s complement representation, this operation +is applicable on both signed and unsigned integer operands. + +The `addi` operation takes two operands and returns one result, each of +these is required to be the same type. This type may be an integer scalar type, +a vector whose element type is integer, or a tensor of integers. + +This op supports `nuw`/`nsw` overflow flags which stands stand for +\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or +`nsw` flags are present, and an unsigned/signed overflow occurs +(respectively), the result is poison. + +# Example + +```mlir +// Scalar addition. +%a = arith.addi %b, %c : i64 + +// Scalar addition with overflow flags. +%a = arith.addi %b, %c overflow : i64 + +// SIMD vector element-wise addition. +%f = arith.addi %g, %h : vector<4xi32> + +// Tensor element-wise addition. +%x = arith.addi %y, %z : tensor<4x?xi8> +``` +""" +function addi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, overflowFlags=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(overflowFlags) && push!(attributes, namedattribute("overflowFlags", overflowFlags)) + + create_operation( + "arith.addi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`addui_extended` + +Performs (N+1)-bit addition on zero-extended operands. Returns two results: +the N-bit sum (same type as both operands), and the overflow bit +(boolean-like), where `1` indicates unsigned addition overflow, while `0` +indicates no overflow. + +# Example + +```mlir +// Scalar addition. +%sum, %overflow = arith.addui_extended %b, %c : i64, i1 + +// Vector element-wise addition. +%d:2 = arith.addui_extended %e, %f : vector<4xi32>, vector<4xi1> + +// Tensor element-wise addition. 
+%x:2 = arith.addui_extended %y, %z : tensor<4x?xi8>, tensor<4x?xi1> +``` +""" +function addui_extended(lhs::Value, rhs::Value; sum::IR.Type, overflow::IR.Type, location=Location()) + op_ty_results = IR.Type[sum, overflow, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.addui_extended", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`andi` + +The `andi` operation takes two operands and returns one result, each of +these is required to be the same type. This type may be an integer scalar +type, a vector whose element type is integer, or a tensor of integers. It +has no standard attributes. + +# Example + +```mlir +// Scalar integer bitwise and. +%a = arith.andi %b, %c : i64 + +// SIMD vector element-wise bitwise integer and. +%f = arith.andi %g, %h : vector<4xi32> + +// Tensor element-wise bitwise integer and. +%x = arith.andi %y, %z : tensor<4x?xi8> +``` +""" +function andi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.andi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`bitcast` + +Bitcast an integer or floating point value to an integer or floating point +value of equal bit width. When operating on vectors, casts elementwise. + +Note that this implements a logical bitcast independent of target +endianness. This allows constant folding without target information and is +consistent with the bitcast constant folders in LLVM (see +https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168). +For targets where the source and target type have the same endianness (which +is the standard), this cast will also change no bits at runtime, but it may +still require an operation, for example if the machine has different +floating point and integer register files. For targets that have a different +endianness for the source and target types (e.g. float is big-endian and +integer is little-endian) a proper lowering would add operations to swap the +order of words in addition to the bitcast. +""" +function bitcast(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.bitcast", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`ceildivsi` + +Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`. + +Division by zero, or signed division overflow (minimum value divided by -1) +is undefined behavior. When applied to `vector` and `tensor` values, the +behavior is undefined if _any_ of its elements are divided by zero or have a +signed division overflow. + +# Example + +```mlir +// Scalar signed integer division.
+%a = arith.ceildivsi %b, %c : i64 +``` +""" +function ceildivsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.ceildivsi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`ceildivui` + +Unsigned integer division. Rounds towards positive infinity. Treats the +leading bit as the most significant, i.e. for `i16` given two\'s complement +representation, `6 / -2 = 6 / (2^16 - 2) = 1`. + +Division by zero is undefined behavior. When applied to `vector` and +`tensor` values, the behavior is undefined if _any_ elements are divided by +zero. + +# Example + +```mlir +// Scalar unsigned integer division. +%a = arith.ceildivui %b, %c : i64 +``` +""" +function ceildivui(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.ceildivui", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`cmpf` + +The `cmpf` operation compares its two operands according to the float +comparison rules and the predicate specified by the respective attribute. +The predicate defines the type of comparison: (un)orderedness, (in)equality +and signed less/greater than (or equal to) as well as predicates that are +always true or false. The operands must have the same type, and this type +must be a float type, or a vector or tensor thereof. The result is an i1, +or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, +the operands are always treated as signed. The u prefix indicates +*unordered* comparison, not unsigned comparison, so \"une\" means unordered or +not equal. For the sake of readability by humans, custom assembly form for +the operation uses a string-typed attribute for the predicate. The value of +this attribute corresponds to lower-cased name of the predicate constant, +e.g., \"one\" means \"ordered not equal\". The string representation of the +attribute is merely a syntactic sugar and is converted to an integer +attribute by the parser. + +# Example + +```mlir +%r1 = arith.cmpf oeq, %0, %1 : f32 +%r2 = arith.cmpf ult, %0, %1 : tensor<42x42xf64> +%r3 = \"arith.cmpf\"(%0, %1) {predicate: 0} : (f8, f8) -> i1 +``` +""" +function cmpf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, predicate, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("predicate", predicate), ] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.cmpf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`cmpi` + +The `cmpi` operation is a generic comparison for integer-like types. Its two +arguments can be integers, vectors or tensors thereof as long as their types +match. The operation produces an i1 for the former case, a vector or a +tensor of i1 with the same shape as inputs in the other cases. + +Its first argument is an attribute that defines which type of comparison is +performed. The following comparisons are supported: + +- equal (mnemonic: `\"eq\"`; integer value: `0`) +- not equal (mnemonic: `\"ne\"`; integer value: `1`) +- signed less than (mnemonic: `\"slt\"`; integer value: `2`) +- signed less than or equal (mnemonic: `\"sle\"`; integer value: `3`) +- signed greater than (mnemonic: `\"sgt\"`; integer value: `4`) +- signed greater than or equal (mnemonic: `\"sge\"`; integer value: `5`) +- unsigned less than (mnemonic: `\"ult\"`; integer value: `6`) +- unsigned less than or equal (mnemonic: `\"ule\"`; integer value: `7`) +- unsigned greater than (mnemonic: `\"ugt\"`; integer value: `8`) +- unsigned greater than or equal (mnemonic: `\"uge\"`; integer value: `9`) + +The result is `1` if the comparison is true and `0` otherwise. For vector or +tensor operands, the comparison is performed elementwise and the element of +the result indicates whether the comparison is true for the operand elements +with the same indices as those of the result. + +Note: while the custom assembly form uses strings, the actual underlying +attribute has integer type (or rather enum class in C++ code) as seen from +the generic assembly form. String literals are used to improve readability +of the IR by humans. + +This operation only applies to integer-like operands, but not floats. The +main reason being that comparison operations have diverging sets of +attributes: integers require sign specification while floats require various +floating point-related particularities, e.g., `-ffast-math` behavior, +IEEE754 compliance, etc +([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)). +The type of comparison is specified as attribute to avoid introducing ten +similar operations, taking into account that they are often implemented +using the same operation downstream +([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). The +separation between signed and unsigned order comparisons is necessary +because of integers being signless. The comparison operation must know how +to interpret values with the foremost bit being set: negatives in two\'s +complement or large positives +([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)). + +# Example + +```mlir +// Custom form of scalar \"signed less than\" comparison. +%x = arith.cmpi slt, %lhs, %rhs : i32 + +// Generic form of the same operation. +%x = \"arith.cmpi\"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1 + +// Custom form of vector equality comparison. +%x = arith.cmpi eq, %lhs, %rhs : vector<4xi64> + +// Generic form of the same operation. 
+%x = \"arith.cmpi\"(%lhs, %rhs) {predicate = 0 : i64} + : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> +``` +""" +function cmpi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, predicate, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("predicate", predicate), ] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.cmpi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`constant` + +The `constant` operation produces an SSA value equal to some integer or +floating-point constant specified by an attribute. This is the way MLIR +forms simple integer and floating point constants. + +# Example + +``` +// Integer constant +%1 = arith.constant 42 : i32 + +// Equivalent generic form +%1 = \"arith.constant\"() {value = 42 : i32} : () -> i32 +``` +""" +function constant(; result=nothing::Union{Nothing, IR.Type}, value, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("value", value), ] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.constant", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + + +function divf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.divf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`divsi` + +Signed integer division. Rounds towards zero. Treats the leading bit as +sign, i.e. `6 / -2 = -3`. + +Divison by zero, or signed division overflow (minimum value divided by -1) +is undefined behavior. When applied to `vector` and `tensor` values, the +behavior is undefined if _any_ of its elements are divided by zero or has a +signed division overflow. + +# Example + +```mlir +// Scalar signed integer division. +%a = arith.divsi %b, %c : i64 + +// SIMD vector element-wise division. +%f = arith.divsi %g, %h : vector<4xi32> + +// Tensor element-wise integer division. +%x = arith.divsi %y, %z : tensor<4x?xi8> +``` +""" +function divsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.divsi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`divui` + +Unsigned integer division. Rounds towards zero. 
Treats the leading bit as +the most significant, i.e. for `i16` given two\'s complement representation, +`6 / -2 = 6 / (2^16 - 2) = 0`. + +Division by zero is undefined behavior. When applied to `vector` and +`tensor` values, the behavior is undefined if _any_ elements are divided by +zero. + +# Example + +```mlir +// Scalar unsigned integer division. +%a = arith.divui %b, %c : i64 + +// SIMD vector element-wise division. +%f = arith.divui %g, %h : vector<4xi32> + +// Tensor element-wise integer division. +%x = arith.divui %y, %z : tensor<4x?xi8> +``` +""" +function divui(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.divui", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`extf` + +Cast a floating-point value to a larger floating-point-typed value. +The destination type must be strictly wider than the source type. +When operating on vectors, casts elementwise. +""" +function extf(in::Value; out::IR.Type, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.extf", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`extsi` + +The integer sign extension operation takes an integer input of +width M and an integer destination type of width N. The destination +bit-width must be larger than the input bit-width (N > M). +The top-most (N - M) bits of the output are filled with copies +of the most-significant bit of the input. + +# Example + +```mlir +%1 = arith.constant 5 : i3 // %1 is 0b101 +%2 = arith.extsi %1 : i3 to i6 // %2 is 0b111101 +%3 = arith.constant 2 : i3 // %3 is 0b010 +%4 = arith.extsi %3 : i3 to i6 // %4 is 0b000010 + +%5 = arith.extsi %0 : vector<2 x i32> to vector<2 x i64> +``` +""" +function extsi(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.extsi", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`extui` + +The integer zero extension operation takes an integer input of +width M and an integer destination type of width N. The destination +bit-width must be larger than the input bit-width (N > M). +The top-most (N - M) bits of the output are filled with zeros.
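+ +A hypothetical Julia-side sketch of the `extui` builder defined below (`x` +is an existing i3 `Value` and `i6_t` a pre-built `IR.Type`; names are +illustrative): + +```julia +# Hedged sketch: zero-extend an i3 value to i6. +wide = extui(x; out=i6_t) +```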
+ +# Example + +```mlir + %1 = arith.constant 5 : i3 // %1 is 0b101 + %2 = arith.extui %1 : i3 to i6 // %2 is 0b000101 + %3 = arith.constant 2 : i3 // %3 is 0b010 + %4 = arith.extui %3 : i3 to i6 // %4 is 0b000010 + + %5 = arith.extui %0 : vector<2 x i32> to vector<2 x i64> +``` +""" +function extui(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.extui", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`fptosi` + +Cast from a value interpreted as floating-point to the nearest (rounding +towards zero) signed integer value. When operating on vectors, casts +elementwise. +""" +function fptosi(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.fptosi", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`fptoui` + +Cast from a value interpreted as floating-point to the nearest (rounding +towards zero) unsigned integer value. When operating on vectors, casts +elementwise. +""" +function fptoui(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.fptoui", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`floordivsi` + +Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`. + +Division by zero, or signed division overflow (minimum value divided by -1) +is undefined behavior. When applied to `vector` and `tensor` values, the +behavior is undefined if _any_ of its elements are divided by zero or have a +signed division overflow. + +# Example + +```mlir +// Scalar signed integer division. +%a = arith.floordivsi %b, %c : i64 + +``` +""" +function floordivsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.floordivsi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`index_cast` + +Casts between scalar or vector integers and corresponding \'index\' scalar or +vectors. Index is an integer of platform-specific bit width. If casting to +a wider integer, the value is sign-extended. If casting to a narrower +integer, the value is truncated. +""" +function index_cast(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.index_cast", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`index_castui` + +Casts between scalar or vector integers and corresponding \'index\' scalar or +vectors.
Index is an integer of platform-specific bit width. If casting to +a wider integer, the value is zero-extended. If casting to a narrower +integer, the value is truncated. +""" +function index_castui(in::Value; out::IR.Type, location=Location()) + op_ty_results = IR.Type[out, ] + operands = Value[in, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "arith.index_castui", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`maxnumf` + +Returns the maximum of the two arguments. +If the arguments are -0.0 and +0.0, then the result is either of them. +If one of the arguments is NaN, then the result is the other argument. + +# Example + +```mlir +// Scalar floating-point maximum. +%a = arith.maxnumf %b, %c : f64 +``` +""" +function maxnumf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.maxnumf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + + +function maxsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.maxsi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + + +function maxui(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.maxui", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`maximumf` + +Returns the maximum of the two arguments, treating -0.0 as less than +0.0. +If one of the arguments is NaN, then the result is also NaN. + +# Example + +```mlir +// Scalar floating-point maximum. +%a = arith.maximumf %b, %c : f64 +``` +""" +function maximumf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.maximumf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`minnumf` + +Returns the minimum of the two arguments. 
+If the arguments are -0.0 and +0.0, then the result is either of them. +If one of the arguments is NaN, then the result is the other argument. + +# Example + +```mlir +// Scalar floating-point minimum. +%a = arith.minnumf %b, %c : f64 +``` +""" +function minnumf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.minnumf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + + +function minsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.minsi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + + +function minui(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.minui", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`minimumf` + +Returns the minimum of the two arguments, treating -0.0 as less than +0.0. +If one of the arguments is NaN, then the result is also NaN. + +# Example + +```mlir +// Scalar floating-point minimum. +%a = arith.minimumf %b, %c : f64 +``` +""" +function minimumf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.minimumf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`mulf` + +The `mulf` operation takes two operands and returns one result, each of +these is required to be the same type. This type may be a floating point +scalar type, a vector whose element type is a floating point type, or a +floating point tensor. + +# Example + +```mlir +// Scalar multiplication. +%a = arith.mulf %b, %c : f64 + +// SIMD pointwise vector multiplication, e.g. for Intel SSE. +%f = arith.mulf %g, %h : vector<4xf32> + +// Tensor pointwise multiplication. +%x = arith.mulf %y, %z : tensor<4x?xbf16> +``` + +TODO: In the distant future, this will accept optional attributes for fast +math, contraction, rounding mode, and other controls. 
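+
+At the Julia level, the generated `mulf` wrapper below builds this op. A
+minimal, hypothetical sketch (assuming `lhs` and `rhs` are float-typed
+`Value`s produced by earlier ops in the same block; omitting `result` lets
+the result type be inferred from the operands):
+
+```julia
+# `lhs`, `rhs`: assumed f64 `Value`s from surrounding IR construction.
+prod = mulf(lhs, rhs)                     # result type inferred
+prod = mulf(lhs, rhs; fastmath=fastmath)  # `fastmath` attribute assumed prebuilt
+```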
+""" +function mulf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.mulf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`muli` + +Performs N-bit multiplication on the operands. The operands are interpreted as +unsigned bitvectors. The result is represented by a bitvector containing the +mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth. +Because `arith` integers use a two\'s complement representation, this operation is +applicable on both signed and unsigned integer operands. + +The `muli` operation takes two operands and returns one result, each of +these is required to be the same type. This type may be an integer scalar type, +a vector whose element type is integer, or a tensor of integers. + +This op supports `nuw`/`nsw` overflow flags which stands stand for +\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or +`nsw` flags are present, and an unsigned/signed overflow occurs +(respectively), the result is poison. + +# Example + +```mlir +// Scalar multiplication. +%a = arith.muli %b, %c : i64 + +// Scalar multiplication with overflow flags. +%a = arith.muli %b, %c overflow : i64 + +// SIMD vector element-wise multiplication. +%f = arith.muli %g, %h : vector<4xi32> + +// Tensor element-wise multiplication. +%x = arith.muli %y, %z : tensor<4x?xi8> +``` +""" +function muli(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, overflowFlags=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(overflowFlags) && push!(attributes, namedattribute("overflowFlags", overflowFlags)) + + create_operation( + "arith.muli", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`mulsi_extended` + +Performs (2*N)-bit multiplication on sign-extended operands. Returns two +N-bit results: the low and the high halves of the product. The low half has +the same value as the result of regular multiplication `arith.muli` with +the same operands. + +# Example + +```mlir +// Scalar multiplication. +%low, %high = arith.mulsi_extended %a, %b : i32 + +// Vector element-wise multiplication. +%c:2 = arith.mulsi_extended %d, %e : vector<4xi32> + +// Tensor element-wise multiplication. 
+%x:2 = arith.mulsi_extended %y, %z : tensor<4x?xi8> +``` +""" +function mulsi_extended(lhs::Value, rhs::Value; low=nothing::Union{Nothing, IR.Type}, high=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(low) && push!(op_ty_results, low) + !isnothing(high) && push!(op_ty_results, high) + + create_operation( + "arith.mulsi_extended", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`mului_extended` + +Performs (2*N)-bit multiplication on zero-extended operands. Returns two +N-bit results: the low and the high halves of the product. The low half has +the same value as the result of regular multiplication `arith.muli` with +the same operands. + +# Example + +```mlir +// Scalar multiplication. +%low, %high = arith.mului_extended %a, %b : i32 + +// Vector element-wise multiplication. +%c:2 = arith.mului_extended %d, %e : vector<4xi32> + +// Tensor element-wise multiplication. +%x:2 = arith.mului_extended %y, %z : tensor<4x?xi8> +``` +""" +function mului_extended(lhs::Value, rhs::Value; low=nothing::Union{Nothing, IR.Type}, high=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(low) && push!(op_ty_results, low) + !isnothing(high) && push!(op_ty_results, high) + + create_operation( + "arith.mului_extended", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`negf` + +The `negf` operation computes the negation of a given value. It takes one +operand and returns one result of the same type. This type may be a float +scalar type, a vector whose element type is float, or a tensor of floats. +It has no standard attributes. + +# Example + +```mlir +// Scalar negation value. +%a = arith.negf %b : f64 + +// SIMD vector element-wise negation value. +%f = arith.negf %g : vector<4xf32> + +// Tensor element-wise negation value. +%x = arith.negf %y : tensor<4x?xf8> +``` +""" +function negf(operand::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.negf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`ori` + +The `ori` operation takes two operands and returns one result, each of these +is required to be the same type. This type may be an integer scalar type, a +vector whose element type is integer, or a tensor of integers. It has no +standard attributes. + +# Example + +```mlir +// Scalar integer bitwise or. +%a = arith.ori %b, %c : i64 + +// SIMD vector element-wise bitwise integer or. +%f = arith.ori %g, %h : vector<4xi32> + +// Tensor element-wise bitwise integer or. 
+%x = arith.ori %y, %z : tensor<4x?xi8> +``` +""" +function ori(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.ori", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`remf` + +Returns the floating point division remainder. +The remainder has the same sign as the dividend (lhs operand). +""" +function remf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath)) + + create_operation( + "arith.remf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`remsi` + +Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % +-2 = 0`. + +Division by zero is undefined behavior. When applied to `vector` and +`tensor` values, the behavior is undefined if _any_ elements are divided by +zero. + +# Example + +```mlir +// Scalar signed integer division remainder. +%a = arith.remsi %b, %c : i64 + +// SIMD vector element-wise division remainder. +%f = arith.remsi %g, %h : vector<4xi32> + +// Tensor element-wise integer division remainder. +%x = arith.remsi %y, %z : tensor<4x?xi8> +``` +""" +function remsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.remsi", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`remui` + +Unsigned integer division remainder. Treats the leading bit as the most +significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. + +Division by zero is undefined behavior. When applied to `vector` and +`tensor` values, the behavior is undefined if _any_ elements are divided by +zero. + +# Example + +```mlir +// Scalar unsigned integer division remainder. +%a = arith.remui %b, %c : i64 + +// SIMD vector element-wise division remainder. +%f = arith.remui %g, %h : vector<4xi32> + +// Tensor element-wise integer division remainder. +%x = arith.remui %y, %z : tensor<4x?xi8> +``` +""" +function remui(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.remui", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? 
nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`sitofp`
+
+Cast from a value interpreted as a signed integer to the corresponding
+floating-point value. If the value cannot be exactly represented, it is
+rounded using the default rounding mode. When operating on vectors, casts
+elementwise.
+"""
+function sitofp(in::Value; out::IR.Type, location=Location())
+    op_ty_results = IR.Type[out, ]
+    operands = Value[in, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "arith.sitofp", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`shli`
+
+The `shli` operation shifts the integer value of the first operand to the left
+by the integer value of the second operand. The second operand is interpreted as
+unsigned. The low order bits are filled with zeros. If the value of the second
+operand is greater than or equal to the bitwidth of the first operand, then the
+operation returns poison.
+
+This op supports `nuw`/`nsw` overflow flags which stand for
+\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or
+`nsw` flags are present, and an unsigned/signed overflow occurs
+(respectively), the result is poison.
+
+# Example
+
+```mlir
+%1 = arith.constant 5 : i8 // %1 is 0b00000101
+%2 = arith.constant 3 : i8
+%3 = arith.shli %1, %2 : i8 // %3 is 0b00101000
+%4 = arith.shli %1, %2 overflow : i8
+```
+"""
+function shli(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, overflowFlags=nothing, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+    !isnothing(overflowFlags) && push!(attributes, namedattribute("overflowFlags", overflowFlags))
+
+    create_operation(
+        "arith.shli", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`shrsi`
+
+The `shrsi` operation shifts an integer value of the first operand to the right
+by the value of the second operand. The first operand is interpreted as signed,
+and the second operand is interpreted as unsigned. The high order bits in the
+output are filled with copies of the most-significant bit of the shifted value
+(which means that the sign of the value is preserved). If the value of the second
+operand is greater than or equal to the bitwidth of the first operand, then the
+operation returns poison.
+
+# Example
+
+```mlir
+%1 = arith.constant 160 : i8 // %1 is 0b10100000
+%2 = arith.constant 3 : i8
+%3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
+%4 = arith.constant 96 : i8 // %4 is 0b01100000
+%5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
+```
+"""
+function shrsi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "arith.shrsi", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`shrui`
+
+The `shrui` operation shifts an integer value of the first operand to the right
+by the value of the second operand. The first operand is interpreted as unsigned,
+and the second operand is interpreted as unsigned. The high order bits are always
+filled with zeros. If the value of the second operand is greater than or equal to
+the bitwidth of the first operand, then the operation returns poison.
+
+# Example
+
+```mlir
+%1 = arith.constant 160 : i8 // %1 is 0b10100000
+%2 = arith.constant 3 : i8
+%3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
+```
+"""
+function shrui(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "arith.shrui", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`subf`
+
+The `subf` operation takes two operands and returns one result, each of
+these is required to be the same type. This type may be a floating point
+scalar type, a vector whose element type is a floating point type, or a
+floating point tensor.
+
+# Example
+
+```mlir
+// Scalar subtraction.
+%a = arith.subf %b, %c : f64
+
+// SIMD vector subtraction, e.g. for Intel SSE.
+%f = arith.subf %g, %h : vector<4xf32>
+
+// Tensor subtraction.
+%x = arith.subf %y, %z : tensor<4x?xbf16>
+```
+
+TODO: In the distant future, this will accept optional attributes for fast
+math, contraction, rounding mode, and other controls.
+"""
+function subf(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, fastmath=nothing, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+    !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath))
+
+    create_operation(
+        "arith.subf", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`subi`
+
+Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
+bitvectors. The result is represented by a bitvector containing the mathematical
+value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
+integers use a two\'s complement representation, this operation is applicable on
+both signed and unsigned integer operands.
+
+The `subi` operation takes two operands and returns one result, each of
+these is required to be the same type. This type may be an integer scalar type,
+a vector whose element type is integer, or a tensor of integers.
+
+This op supports `nuw`/`nsw` overflow flags which stand for
+\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or
+`nsw` flags are present, and an unsigned/signed overflow occurs
+(respectively), the result is poison.
+
+# Example
+
+```mlir
+// Scalar subtraction.
+%a = arith.subi %b, %c : i64
+
+// Scalar subtraction with overflow flags.
+%a = arith.subi %b, %c overflow : i64
+
+// SIMD vector element-wise subtraction.
+%f = arith.subi %g, %h : vector<4xi32>
+
+// Tensor element-wise subtraction.
+%x = arith.subi %y, %z : tensor<4x?xi8>
+```
+"""
+function subi(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, overflowFlags=nothing, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+    !isnothing(overflowFlags) && push!(attributes, namedattribute("overflowFlags", overflowFlags))
+
+    create_operation(
+        "arith.subi", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`truncf`
+
+Truncate a floating-point value to a smaller floating-point-typed value.
+The destination type must be strictly narrower than the source type.
+If the value cannot be exactly represented, it is rounded using the
+provided rounding mode or the default one if no rounding mode is provided.
+When operating on vectors, casts elementwise.
+"""
+function truncf(in::Value; out::IR.Type, roundingmode=nothing, fastmath=nothing, location=Location())
+    op_ty_results = IR.Type[out, ]
+    operands = Value[in, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(roundingmode) && push!(attributes, namedattribute("roundingmode", roundingmode))
+    !isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath))
+
+    create_operation(
+        "arith.truncf", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`trunci`
+
+The integer truncation operation takes an integer input of
+width M and an integer destination type of width N. The destination
+bit-width must be smaller than the input bit-width (N < M).
+The top-most (M - N) bits of the input are discarded.
+
+# Example
+
+```mlir
+ %1 = arith.constant 21 : i5 // %1 is 0b10101
+ %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
+ %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101
+
+ %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
+```
+"""
+function trunci(in::Value; out::IR.Type, location=Location())
+    op_ty_results = IR.Type[out, ]
+    operands = Value[in, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "arith.trunci", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`uitofp`
+
+Cast from a value interpreted as unsigned integer to the corresponding
+floating-point value. If the value cannot be exactly represented, it is
+rounded using the default rounding mode. When operating on vectors, casts
+elementwise.
+"""
+function uitofp(in::Value; out::IR.Type, location=Location())
+    op_ty_results = IR.Type[out, ]
+    operands = Value[in, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "arith.uitofp", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`xori`
+
+The `xori` operation takes two operands and returns one result, each of
+these is required to be the same type. This type may be an integer scalar
+type, a vector whose element type is integer, or a tensor of integers. It
+has no standard attributes.
+
+# Example
+
+```mlir
+// Scalar integer bitwise xor.
+%a = arith.xori %b, %c : i64 + +// SIMD vector element-wise bitwise integer xor. +%f = arith.xori %g, %h : vector<4xi32> + +// Tensor element-wise bitwise integer xor. +%x = arith.xori %y, %z : tensor<4x?xi8> +``` +""" +function xori(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.xori", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`select` + +The `arith.select` operation chooses one value based on a binary condition +supplied as its first operand. + +If the value of the first operand (the condition) is `1`, then the second +operand is returned, and the third operand is ignored, even if it was poison. + +If the value of the first operand (the condition) is `0`, then the third +operand is returned, and the second operand is ignored, even if it was poison. + +If the value of the first operand (the condition) is poison, then the +operation returns poison. + +The operation applies to vectors and tensors elementwise given the _shape_ +of all operands is identical. The choice is made for each element +individually based on the value at the same position as the element in the +condition operand. If an i1 is provided as the condition, the entire vector +or tensor is chosen. + +# Example + +```mlir +// Custom form of scalar selection. +%x = arith.select %cond, %true, %false : i32 + +// Generic form of the same operation. +%x = \"arith.select\"(%cond, %true, %false) : (i1, i32, i32) -> i32 + +// Element-wise vector selection. +%vx = arith.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32> + +// Full vector selection. +%vx = arith.select %cond, %vtrue, %vfalse : vector<42xf32> +``` +""" +function select(condition::Value, true_value::Value, false_value::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[condition, true_value, false_value, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "arith.select", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +end # arith diff --git a/src/mlir/Dialects/Builtin.jl b/src/mlir/Dialects/Builtin.jl new file mode 100755 index 000000000..acf5a3f04 --- /dev/null +++ b/src/mlir/Dialects/Builtin.jl @@ -0,0 +1,93 @@ +module builtin +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`module_` + +A `module` represents a top-level container operation. It contains a single +[graph region](../LangRef.md#control-flow-and-ssacfg-regions) containing a single block +which can contain any operations and does not have a terminator. Operations +within this region cannot implicitly capture values defined outside the module, +i.e. Modules are [IsolatedFromAbove](../Traits.md#isolatedfromabove). 
Modules have +an optional [symbol name](../SymbolsAndSymbolTables.md) which can be used to refer +to them in operations. + +# Example + +```mlir +module { + func.func @foo() +} +``` +""" +function module_(; sym_name=nothing, sym_visibility=nothing, bodyRegion::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[bodyRegion, ] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(sym_name) && push!(attributes, namedattribute("sym_name", sym_name)) + !isnothing(sym_visibility) && push!(attributes, namedattribute("sym_visibility", sym_visibility)) + + create_operation( + "builtin.module", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`unrealized_conversion_cast` + +An `unrealized_conversion_cast` operation represents an unrealized +conversion from one set of types to another, that is used to enable the +inter-mixing of different type systems. This operation should not be +attributed any special representational or execution semantics, and is +generally only intended to be used to satisfy the temporary intermixing of +type systems during the conversion of one type system to another. + +This operation may produce results of arity 1-N, and accept as input +operands of arity 0-N. + +# Example + +```mlir +// An unrealized 0-1 conversion. These types of conversions are useful in +// cases where a type is removed from the type system, but not all uses have +// been converted. For example, imagine we have a tuple type that is +// expanded to its element types. If only some uses of an empty tuple type +// instance are converted we still need an instance of the tuple type, but +// have no inputs to the unrealized conversion. +%result = unrealized_conversion_cast to !bar.tuple_type<> + +// An unrealized 1-1 conversion. +%result1 = unrealized_conversion_cast %operand : !foo.type to !bar.lowered_type + +// An unrealized 1-N conversion. +%results2:2 = unrealized_conversion_cast %tuple_operand : !foo.tuple_type to !foo.type, !foo.type + +// An unrealized N-1 conversion. +%result3 = unrealized_conversion_cast %operand, %operand : !foo.type, !foo.type to !bar.tuple_type +``` +""" +function unrealized_conversion_cast(inputs::Vector{Value}; outputs::Vector{IR.Type}, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "builtin.unrealized_conversion_cast", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +end # builtin diff --git a/src/mlir/Dialects/CHLO.jl b/src/mlir/Dialects/CHLO.jl new file mode 100755 index 000000000..0de3a871c --- /dev/null +++ b/src/mlir/Dialects/CHLO.jl @@ -0,0 +1,1158 @@ +module chlo +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`acos` + +Returns `Acos(operand)` element-wise. 
+ +\$\$ +\\acos(x) = 2 * \\atan(\\sqrt(1 - x^2) / (1 + x)) if x != -1 + = pi if x == -1 +\$\$ +""" +function acos(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.acos", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`acosh` + +Returns `Acosh(operand)` element-wise. + +\$\$ +\\acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 +\\acosh(x) = nan if x < -1 +\$\$ +""" +function acosh(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.acosh", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`_asin_acos_kernel` + +Returns `AsinAcosKernel(operand)` element-wise. + +If + w = _asin_acos_kernel(z) + w\' = _asin_acos_kernel(I * z) +then + asin(z) = complex(atan2(z.real, w.real), sign(z.imag) * w.imag) + acos(z) = complex(atan2(w.real, z.real), -sign(z.imag) * w.imag) + asinh(z) = complex(sign(z.real) * w\'.imag, atan2(z.imag, w\'.real)) + acosh(z) = complex(w.imag, sign(z.imag) * atan2(w.real, z.real)) + +This op is used as an intermediate value in decompositions and +should never be constructed directly by frameworks or consumed by +backends. +""" +function _asin_acos_kernel(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo._asin_acos_kernel", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`asin` + +Returns `Asin(operand)` element-wise. + +\$\$ +\\asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +\$\$ +""" +function asin(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.asin", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`asinh` + +Returns `Asinh(operand)` element-wise. 
+ +\$\$ +\\asinh(x) = log(x + sqrt(x^2 + 1)) +\$\$ +""" +function asinh(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.asinh", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`atan` + +Returns `Atan(operand)` element-wise. + +\$\$ +\\atan(x) = \\atan2(x, 1) +\$\$ +""" +function atan(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.atan", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`atanh` + +Returns `Atanh(operand)` element-wise. + +\$\$ +\\atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 + = nan otherwise +\$\$ +""" +function atanh(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.atanh", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`bessel_i1e` + +Returns `bessel_i1e(operand)` element-wise. +""" +function bessel_i1e(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.bessel_i1e", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_add` + +Returns `lhs + rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_add(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_add", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_and` + +Returns `logical_and(lhs, rhs)` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
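+
+All of the binary `chlo.broadcast_*` wrappers in this module share the
+calling pattern sketched below (hypothetical names: `lhs` and `rhs` are
+tensor-typed `Value`s, `dims` is a prebuilt broadcast-dimensions attribute;
+both keyword arguments may be omitted, in which case the result type is
+inferred):
+
+```julia
+out = broadcast_and(lhs, rhs)                            # result type inferred
+out = broadcast_and(lhs, rhs; broadcast_dimensions=dims) # explicit broadcast
+```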
+""" +function broadcast_and(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_and", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_atan2` + +Returns `atan2(lhs/rhs)` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_atan2(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_atan2", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_compare` + +Compares `lhs` and `rhs` elementwise according to `comparison_direction` +and `compare_type`. If unspecified, `compare_type` is FLOAT for float element +types, SIGNED for signed element types and UNSIGNED for unsigned element +types. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. +""" +function broadcast_compare(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, comparison_direction, compare_type=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("comparison_direction", comparison_direction), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + !isnothing(compare_type) && push!(attributes, namedattribute("compare_type", compare_type)) + + create_operation( + "chlo.broadcast_compare", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_complex` + +Performs element-wise conversion of a pair of real and imaginary values to +a complex value. 
+""" +function broadcast_complex(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_complex", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_divide` + +Returns `lhs / rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_divide(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_divide", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_maximum` + +Returns `max(lhs, rhs)` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_maximum(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_maximum", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_minimum` + +Returns `min(lhs, rhs)` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_minimum(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_minimum", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_multiply` + +Returns `lhs * rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+""" +function broadcast_multiply(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_multiply", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_next_after` + +Returns the next representable value of `lhs` in the direction of `rhs`, +element-wise. It can also return a subnormal number. + +Equivalent to the C++ std::nextafter function. +""" +function broadcast_next_after(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_next_after", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_or` + +Returns `logical_or(lhs, rhs)` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_or(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_or", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_polygamma` + +Returns `Polygamma(operand, operand)` element-wise. +""" +function broadcast_polygamma(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_polygamma", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_power` + +Returns `lhs ^ rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+""" +function broadcast_power(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_power", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_remainder` + +Returns `lhs % rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_remainder(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_remainder", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_select` + +Constructs an output array from elements of two input arrays, based on the +values of a predicate array. + +See https://www.tensorflow.org/xla/operation_semantics#select +""" +function broadcast_select(pred::Value, on_true::Value, on_false::Value; result_0=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[pred, on_true, on_false, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "chlo.broadcast_select", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_shift_left` + +Returns `lhs << rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_shift_left(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_shift_left", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_shift_right_arithmetic` + +Returns `lhs >> rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+""" +function broadcast_shift_right_arithmetic(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_shift_right_arithmetic", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_shift_right_logical` + +Returns `lhs >> rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_shift_right_logical(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_shift_right_logical", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_subtract` + +Returns `lhs - rhs` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_subtract(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_subtract", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_xor` + +Returns `logical_xor(lhs, rhs)` element-wise. + +See +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. +""" +function broadcast_xor(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions)) + + create_operation( + "chlo.broadcast_xor", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`broadcast_zeta` + +Returns `Zeta(operand, operand)` element-wise. 
+
+\$\$
+\\zeta(x, q) = \\sum_{n=0}^{\\infty} (q + n)^{-x}
+\$\$
+"""
+function broadcast_zeta(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_dimensions=nothing, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[lhs, rhs, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result_0) && push!(op_ty_results, result_0)
+ !isnothing(broadcast_dimensions) && push!(attributes, namedattribute("broadcast_dimensions", broadcast_dimensions))
+
+ create_operation(
+ "chlo.broadcast_zeta", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`conj`
+
+Returns `Conj(operand)` element-wise.
+
+\$\$
+\\operatorname{conj}(x) = (\\operatorname{Re}(x), -\\operatorname{Im}(x))
+\$\$
+"""
+function conj(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.conj", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`constant_like`
+
+Returns a splat constant of the same shape as the operand.
+"""
+function constant_like(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, value, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("value", value), ]
+ !isnothing(result_0) && push!(op_ty_results, result_0)
+
+ create_operation(
+ "chlo.constant_like", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`constant`
+
+Represents a constant value.
+"""
+function constant(; output=nothing::Union{Nothing, IR.Type}, value, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("value", value), ]
+ !isnothing(output) && push!(op_ty_results, output)
+
+ create_operation(
+ "chlo.constant", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`cosh`
+
+Returns `Cosh(operand)` element-wise.
+
+\$\$
+\\cosh(x) = (e^x + e^{-x}) / 2
+\$\$
+"""
+function cosh(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.cosh", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`digamma`
+
+Returns `Digamma(operand)` element-wise.
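+
+A minimal builder sketch (assuming `x` is a floating-point tensor `Value`
+produced earlier in the block; the result type is inferred):
+
+```julia
+dg = digamma(x)
+```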
+""" +function digamma(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.digamma", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`erf_inv` + +Returns `ErfInv(operand)` element-wise. +""" +function erf_inv(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.erf_inv", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`erf` + +Computes the Gauss error function of `x` element-wise. + +erf(x) = erf_impl(x) if |x| < 1 + = 1 - erfc_impl(x) otherwise +""" +function erf(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.erf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`erfc` + +Computes an approximation of the error function complement (1 - erf(x)). + +erfc(x) = erfc_impl(x) if |x| > 1 + = 1 - erf_impl(x) otherwise +""" +function erfc(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.erfc", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`is_inf` + +Returns if a value is +/-inf element-wise. +""" +function is_inf(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "chlo.is_inf", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`is_neg_inf` + +Returns if a value is -inf element-wise. 
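+
+Usage sketch (hypothetical `x`, an existing tensor `Value`; the boolean
+result type is inferred):
+
+```julia
+neg_inf_mask = is_neg_inf(x)
+```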
+"""
+function is_neg_inf(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.is_neg_inf", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`is_pos_inf`
+
+Returns if a value is +inf element-wise.
+"""
+function is_pos_inf(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.is_pos_inf", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`lgamma`
+
+Returns `Lgamma(operand)` element-wise.
+"""
+function lgamma(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.lgamma", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`next_after`
+
+Returns the next representable value of `x` in the direction of `y`,
+element-wise. It can also return a subnormal number.
+
+Equivalent to the C++ std::nextafter function.
+"""
+function next_after(x::Value, y::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[x, y, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.next_after", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`polygamma`
+
+Returns `Polygamma(operand, operand)` element-wise.
+"""
+function polygamma(n::Value, x::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[n, x, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.polygamma", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`sinh`
+
+Returns `Sinh(operand)` element-wise.
+
+\$\$
+\\sinh(x) = \\begin{cases}
+ (e^x - e^{-x}) / 2 & \\text{if } |x| < 1 \\\\
+ e^{x + \\log(1/2)} - e^{-x + \\log(1/2)} & \\text{otherwise}
+\\end{cases}
+\$\$
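+
+Builder sketch (assuming `x` is an existing floating-point tensor `Value`):
+
+```julia
+y = sinh(x)  # result type inferred
+```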
+"""
+function sinh(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.sinh", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`tan`
+
+Returns `Tan(operand)` element-wise.
+
+\$\$
+\\tan(x) = \\sin(x) / \\cos(x)
+\$\$
+"""
+function tan(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.tan", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`top_k`
+
+If the input is a vector (rank-1), finds the `k` largest entries in the vector
+and outputs their values and indices as vectors. Thus `values[j]` is the
+`j`-th largest entry in `input`, and its index is `indices[j]`.
+
+For matrices (resp. higher rank input), computes the top `k` entries in each
+row (resp. vector along the last dimension). Thus,
+
+ values.shape = indices.shape = input.shape[:-1] + [k]
+
+If two elements are equal, the lower-index element appears first.
+"""
+function top_k(operand::Value; values=nothing::Union{Nothing, IR.Type}, indices=nothing::Union{Nothing, IR.Type}, k, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("k", k), ]
+ !isnothing(values) && push!(op_ty_results, values)
+ !isnothing(indices) && push!(op_ty_results, indices)
+
+ create_operation(
+ "chlo.top_k", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`zeta`
+
+Returns `Zeta(operand, operand)` element-wise.
+
+\$\$
+\\zeta(x, q) = \\sum_{n=0}^{\\infty} (q + n)^{-x}
+\$\$
+"""
+function zeta(x::Value, q::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[x, q, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "chlo.zeta", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +end # chlo diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl new file mode 100755 index 000000000..0b1b758b7 --- /dev/null +++ b/src/mlir/Dialects/Enzyme.jl @@ -0,0 +1,191 @@ +module enzyme +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`addTo` + +TODO +""" +function addTo(values::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[values..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "enzyme.addTo", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + + create_operation( + "enzyme.autodiff", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] + + create_operation( + "enzyme.batch", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + + create_operation( + "enzyme.fwddiff", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) + op_ty_results = IR.Type[result_tensors..., ] + operands = Value[inputs..., outputs..., ] + owned_regions = Region[region, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) + !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) + !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) + + create_operation( + "enzyme.genericAdjoint", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function get(gradient::Value; result_0::IR.Type, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[gradient, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + 
create_operation( + "enzyme.get", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function init(; result_0::IR.Type, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "enzyme.init", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function placeholder(; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "enzyme.placeholder", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function pop(cache::Value; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[cache, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "enzyme.pop", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function push(cache::Value, value::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[cache, value, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "enzyme.push", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function set(gradient::Value, value::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[gradient, value, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "enzyme.set", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +end # enzyme diff --git a/src/mlir/Dialects/Func.jl b/src/mlir/Dialects/Func.jl new file mode 100755 index 000000000..1549d8315 --- /dev/null +++ b/src/mlir/Dialects/Func.jl @@ -0,0 +1,194 @@ +module func +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`call_indirect` + +The `func.call_indirect` operation represents an indirect call to a value +of function type. The operands and result types of the call must match the +specified function type. + +Function values can be created with the +[`func.constant` operation](#funcconstant-constantop). + +# Example + +```mlir +%func = func.constant @my_func : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> +%result = func.call_indirect %func(%0, %1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> +``` +""" +function call_indirect(callee::Value, callee_operands::Vector{Value}; results::Vector{IR.Type}, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[callee, callee_operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "func.call_indirect", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`call` + +The `func.call` operation represents a direct call to a function that is +within the same symbol scope as the call. 
The operands and result types of +the call must match the specified function type. The callee is encoded as a +symbol reference attribute named \"callee\". + +# Example + +```mlir +%2 = func.call @my_add(%0, %1) : (f32, f32) -> f32 +``` +""" +function call(operands::Vector{Value}; result_0::Vector{IR.Type}, callee, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("callee", callee), ] + + create_operation( + "func.call", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`constant` + +The `func.constant` operation produces an SSA value from a symbol reference +to a `func.func` operation + +# Example + +```mlir +// Reference to function @myfn. +%2 = func.constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> + +// Equivalent generic forms +%2 = \"func.constant\"() { value = @myfn } : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) +``` + +MLIR does not allow direct references to functions in SSA operands because +the compiler is multithreaded, and disallowing SSA values to directly +reference a function simplifies this +([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). +""" +function constant(; result_0::IR.Type, value, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("value", value), ] + + create_operation( + "func.constant", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`func_` + +Operations within the function cannot implicitly capture values defined +outside of the function, i.e. Functions are `IsolatedFromAbove`. All +external references must use function arguments or attributes that establish +a symbolic connection (e.g. symbols referenced by name via a string +attribute like SymbolRefAttr). An external function declaration (used when +referring to a function declared in some other module) has no body. While +the MLIR textual form provides a nice inline syntax for function arguments, +they are internally represented as “block arguments” to the first block in +the region. + +Only dialect attribute names may be specified in the attribute dictionaries +for function arguments, results, or the function itself. + +# Example + +```mlir +// External function definitions. 
+func.func private @abort() +func.func private @scribble(i32, i64, memref) -> f64 + +// A function that returns its argument twice: +func.func @count(%x: i64) -> (i64, i64) + attributes {fruit: \"banana\"} { + return %x, %x: i64, i64 +} + +// A function with an argument attribute +func.func private @example_fn_arg(%x: i32 {swift.self = unit}) + +// A function with a result attribute +func.func private @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + +// A function with an attribute +func.func private @example_fn_attr() attributes {dialectName.attrName = false} +``` +""" +function func_(; sym_name, function_type, sym_visibility=nothing, arg_attrs=nothing, res_attrs=nothing, body::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[body, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name), namedattribute("function_type", function_type), ] + !isnothing(sym_visibility) && push!(attributes, namedattribute("sym_visibility", sym_visibility)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + + create_operation( + "func.func", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`return_` + +The `func.return` operation represents a return operation within a function. +The operation takes variable number of operands and produces no results. +The operand number and types must match the signature of the function +that contains the operation. + +# Example + +```mlir +func.func @foo() : (i32, f8) { + ... + return %0, %1 : i32, f8 +} +``` +""" +function return_(operands::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "func.return", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +end # func diff --git a/src/mlir/Dialects/StableHLO.jl b/src/mlir/Dialects/StableHLO.jl new file mode 100755 index 000000000..98f49bebb --- /dev/null +++ b/src/mlir/Dialects/StableHLO.jl @@ -0,0 +1,3720 @@ +module stablehlo +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + +""" +`abs` + +Performs element-wise abs operation on `operand` tensor and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#abs + +# Example +```mlir +%result = stablehlo.abs %operand : tensor<3xi32> +``` +""" +function abs(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.abs", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`add` + +Performs element-wise addition of two tensors `lhs` and `rhs` and produces a +`result` tensor. 
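+
+From Julia, the generated builder can be called directly; a hedged sketch
+(`lhs` and `rhs` are assumed, pre-existing `Value`s of matching tensor type):
+
+```julia
+sum_val = add(lhs, rhs)  # omitting `result` lets MLIR infer the type
+```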
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#add + +# Example +```mlir +%result = stablehlo.add %lhs, %rhs : tensor<2x2xi32> +``` +""" +function add(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.add", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`after_all` + +Ensures that the operations producing the `inputs` are executed before any +operations that depend on `result`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all + +# Example +```mlir +%result = stablehlo.after_all %input0, %input1 : !stablehlo.token +``` +""" +function after_all(inputs::Vector{Value}; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.after_all", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`all_gather` + +Within each process group in the process grid, concatenates the values of the +`operand` tensor from each process along `all_gather_dim` and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather + +# Example +```mlir +%result:2 = \"stablehlo.all_gather\"(%operand0, %operand1) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle +} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>) +``` +""" +function all_gather(operands::Vector{Value}; result_0::Vector{IR.Type}, all_gather_dim, replica_groups, channel_handle=nothing, use_global_device_ids=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("all_gather_dim", all_gather_dim), namedattribute("replica_groups", replica_groups), ] + !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) + !isnothing(use_global_device_ids) && push!(attributes, namedattribute("use_global_device_ids", use_global_device_ids)) + + create_operation( + "stablehlo.all_gather", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`all_reduce` + +Within each process group in the process grid, applies a reduction function +`computation` to the values of the `operand` tensor from each process and +produces a `result` tensor. 
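+
+A hedged Julia sketch (all names are assumptions: `xs` is a vector of
+operand `Value`s, `result_tys` a vector of pre-built `IR.Type`s,
+`groups_attr` a replica-groups attribute, and `body` a `Region` holding
+the reduction computation, each constructed elsewhere):
+
+```julia
+outs = all_reduce(xs; result_0=result_tys,
+                  replica_groups=groups_attr, computation=body)
+```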
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce + +# Example +```mlir +%result:2 = \"stablehlo.all_reduce\"(%operand0, %operand0) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = \"stablehlo.add\"(%arg0, %arg1) : (tensor, tensor) -> tensor + \"stablehlo.return\"(%0) : (tensor) -> () +}) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle +} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>) +``` +""" +function all_reduce(operands::Vector{Value}; result_0::Vector{IR.Type}, replica_groups, channel_handle=nothing, use_global_device_ids=nothing, computation::Region, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[operands..., ] + owned_regions = Region[computation, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), ] + !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) + !isnothing(use_global_device_ids) && push!(attributes, namedattribute("use_global_device_ids", use_global_device_ids)) + + create_operation( + "stablehlo.all_reduce", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`all_to_all` + +Within each process group in the process grid, splits the values of the +`operand` tensor along `split_dimension` into parts, scatters the split parts +between the processes, concatenates the scattered parts along `concat_dimension` +and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_to_all + +# Example +```mlir +%result:2 = \"stablehlo.all_to_all\"(%operand1, %operand2) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 2 : i64, + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> +} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>) +``` +""" +function all_to_all(operands::Vector{Value}; result_0=nothing::Union{Nothing, Vector{IR.Type}}, split_dimension, concat_dimension, split_count, replica_groups, channel_handle=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("split_dimension", split_dimension), namedattribute("concat_dimension", concat_dimension), namedattribute("split_count", split_count), namedattribute("replica_groups", replica_groups), ] + !isnothing(result_0) && push!(op_ty_results, result_0...) + !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) + + create_operation( + "stablehlo.all_to_all", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`and` + +Performs element-wise AND of two tensors `lhs` and `rhs` and produces a +`result` tensor + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#and + +# Example +```mlir +%result = stablehlo.and %lhs, %rhs : tensor<2x2xi32> +``` +""" +function and(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.and", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`atan2` + +Performs element-wise atan2 operation on `lhs` and `rhs` tensor and produces +a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#atan2 + +# Example +```mlir +%result = stablehlo.atan2 %lhs, %rhs : tensor<3xf64> +``` +""" +function atan2(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.atan2", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`batch_norm_grad` + +Computes gradients of several inputs of BatchNormTrainingOp backpropagating +from `grad_output`, and produces `grad_operand`, `grad_scale` and +`grad_offset` tensors. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_grad + +# Example +```mlir +%grad_operand, %grad_scale, %grad_offset = +\"stablehlo.batch_norm_grad\"(%operand, %scale, %mean, %variance, %grad_output) { + epsilon = 0.0 : f32, + feature_index = 2 : i64 +} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, + tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) +``` +""" +function batch_norm_grad(operand::Value, scale::Value, mean::Value, variance::Value, grad_output::Value; grad_operand=nothing::Union{Nothing, IR.Type}, grad_scale=nothing::Union{Nothing, IR.Type}, grad_offset=nothing::Union{Nothing, IR.Type}, epsilon, feature_index, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, scale, mean, variance, grad_output, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("epsilon", epsilon), namedattribute("feature_index", feature_index), ] + !isnothing(grad_operand) && push!(op_ty_results, grad_operand) + !isnothing(grad_scale) && push!(op_ty_results, grad_scale) + !isnothing(grad_offset) && push!(op_ty_results, grad_offset) + + create_operation( + "stablehlo.batch_norm_grad", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`batch_norm_inference` + +Normalizes the `operand` tensor across all dimensions except for the +`feature_index` dimension and produces a `result` tensor. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_inference + +# Example +```mlir +%result = \"stablehlo.batch_norm_inference\"(%operand, %scale, %offset, %mean, %variance) { + epsilon = 0.0 : f32, + feature_index = 2 : i64 +} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64> +``` +""" +function batch_norm_inference(operand::Value, scale::Value, offset::Value, mean::Value, variance::Value; result=nothing::Union{Nothing, IR.Type}, epsilon, feature_index, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, scale, offset, mean, variance, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("epsilon", epsilon), namedattribute("feature_index", feature_index), ] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.batch_norm_inference", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`batch_norm_training` + +Computes mean and variance across batch and spatial dimensions and +normalizes the `operand` tensor, for each feature in the `feature_index` +dimension and produces `output`, `batch_mean` and `batch_var` tensors. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_training + +# Example +```mlir +%output, %batch_mean, %batch_var = \"stablehlo.batch_norm_training\"(%operand, %scale, %offset) { + epsilon = 0.0 : f32, + feature_index = 2 : i64 +} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) -> + (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) +``` +""" +function batch_norm_training(operand::Value, scale::Value, offset::Value; output=nothing::Union{Nothing, IR.Type}, batch_mean=nothing::Union{Nothing, IR.Type}, batch_var=nothing::Union{Nothing, IR.Type}, epsilon, feature_index, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, scale, offset, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("epsilon", epsilon), namedattribute("feature_index", feature_index), ] + !isnothing(output) && push!(op_ty_results, output) + !isnothing(batch_mean) && push!(op_ty_results, batch_mean) + !isnothing(batch_var) && push!(op_ty_results, batch_var) + + create_operation( + "stablehlo.batch_norm_training", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`bitcast_convert` + +Performs a bitcast operation on `operand` tensor and produces a `result` +tensor where the bits of the entire `operand` tensor are reinterpreted using +the type of the `result` tensor. 
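+
+Because the result type cannot be inferred from the operand alone, the
+builder takes it explicitly; a sketch (`x` is an existing `Value` and
+`result_ty` a pre-built `IR.Type` describing the reinterpreted tensor):
+
+```julia
+bits = bitcast_convert(x; result_0=result_ty)
+```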
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#bitcast_convert + +# Example +```mlir +%result = stablehlo.bitcast_convert %operand : (tensor) -> tensor<4xf16> +``` +""" +function bitcast_convert(operand::Value; result_0::IR.Type, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.bitcast_convert", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`broadcast_in_dim` + +Expands the dimensions and/or rank of an input tensor by duplicating the +data in the `operand` tensor and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim + +# Example +```mlir +%result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> +``` +""" +function broadcast_in_dim(operand::Value; result_0::IR.Type, broadcast_dimensions, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("broadcast_dimensions", broadcast_dimensions), ] + + create_operation( + "stablehlo.broadcast_in_dim", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`broadcast` + +This operation is on its way out of StableHLO, so it is not included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. + +Informally, this operation does the same thing as XLA\'s Broadcast: +https://www.tensorflow.org/xla/operation_semantics#broadcast + +# Example +```mlir +%result = stablehlo.broadcast %operand, sizes = [1, 2] : (tensor<3xi32>) -> tensor<1x2x3xi32> +``` +""" +function broadcast(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, broadcast_sizes, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("broadcast_sizes", broadcast_sizes), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.broadcast", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`case` + +Produces the output from executing exactly one `function` from `branches` +depending on the value of `index`. 
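+
+Hedged sketch (assumed names: `idx` is an i32 scalar-tensor `Value`,
+`branch0` and `branch1` are `Region`s assembled elsewhere, and
+`result_tys` lists the result types):
+
+```julia
+outs = case(idx; result_0=result_tys, branches=[branch0, branch1])
+```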
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + +# Example +```mlir +%result0, %result1 = \"stablehlo.case\"(%index) ({ + stablehlo.return %result_branch0, %result_branch0 : tensor<2xi64>, tensor<2xi64> +}, { + stablehlo.return %result_branch1, %result_branch1 : tensor<2xi64>, tensor<2xi64> +}) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) +``` +""" +function case(index::Value; result_0::Vector{IR.Type}, branches::Vector{Region}, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[index, ] + owned_regions = Region[branches..., ] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.case", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`cbrt` + +Performs element-wise cubic root operation on `operand` tensor and produces +a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cbrt + +# Example +```mlir +%result = stablehlo.cbrt %operand : tensor<4xf64> +``` +""" +function cbrt(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.cbrt", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`ceil` + +Performs element-wise ceil of `operand` tensor and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#ceil + +# Example +```mlir +%result = stablehlo.ceil %operand : tensor<5xf32> +``` +""" +function ceil(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.ceil", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`cholesky` + +Computes the Cholesky decomposition of a batch of matrices. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cholesky + +# Example +```mlir +%result = stablehlo.cholesky %a, lower = true : tensor<3x3xf64> +``` +""" +function cholesky(a::Value; result=nothing::Union{Nothing, IR.Type}, lower=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[a, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(lower) && push!(attributes, namedattribute("lower", lower)) + + create_operation( + "stablehlo.cholesky", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`clamp` + +Clamps every element of the `operand` tensor between a minimum and maximum +value and produces a `result` tensor. 
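+
+Builder sketch (all three operands are assumed, pre-existing `Value`s of
+compatible type; the result type is inferred):
+
+```julia
+clamped = clamp(minval, x, maxval)
+```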
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp + +# Example +```mlir +%result = stablehlo.clamp %min, %operand, %max : tensor<3xi32> +``` +""" +function clamp(min::Value, operand::Value, max::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[min, operand, max, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.clamp", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`count_leading_zeros` + +Performs element-wise count of the number of leading zero bits in the +`operand` tensor and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros + +# Example +```mlir +%result = stablehlo.count_leading_zeros %operand : tensor<2x2xi64> +``` +""" +function count_leading_zeros(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.count_leading_zeros", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`collective_broadcast` + +Within each process group in the process grid, send the value of the +`operand` tensor from the source process to the target processes and produce a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast + +# Example +```mlir +%result = \"stablehlo.collective_broadcast\"(%operand) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle +} : (tensor<1x2xi64>) -> tensor<1x2xi64> +``` +""" +function collective_broadcast(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, replica_groups, channel_handle=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) + + create_operation( + "stablehlo.collective_broadcast", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`collective_permute` + +Within each process group in the process grid, sends the value of the +`operand` tensor from the source process to the target process and produces +a `result` tensor. 
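+
+Sketch (`x` is an existing `Value`; `pairs_attr` stands for the dense
+source/target-pairs attribute, built elsewhere):
+
+```julia
+permuted = collective_permute(x; source_target_pairs=pairs_attr)
+```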
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_permute + +# Example +```mlir +%result = \"stablehlo.collective_permute\"(%operand) { + source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, + channel_handle = #stablehlo.channel_handle +} : (tensor<2x2xi64>) -> tensor<2x2xi64> +``` +""" +function collective_permute(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, source_target_pairs, channel_handle=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("source_target_pairs", source_target_pairs), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle)) + + create_operation( + "stablehlo.collective_permute", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`compare` + +Performs element-wise comparison of `lhs` and `rhs` tensors according to +`comparison_direction` and `compare_type`, and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#compare + +# Example +```mlir +%result = stablehlo.compare LT, %lhs, %rhs, FLOAT : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +``` +""" +function compare(lhs::Value, rhs::Value; result_0=nothing::Union{Nothing, IR.Type}, comparison_direction, compare_type=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("comparison_direction", comparison_direction), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(compare_type) && push!(attributes, namedattribute("compare_type", compare_type)) + + create_operation( + "stablehlo.compare", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`complex` + +Performs element-wise conversion to a complex value from a pair of real and +imaginary values, `lhs` and `rhs`, and produces a `result` tensor. +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex +# Example +```mlir +%result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex> +``` +""" +function complex(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.complex", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`composite` + +Encapsulates an operation made up (composed) of other StableHLO operations, +taking `inputs` and `composite_attributes` and producing `results`. The +semantics of the op are implemented by the `decomposition` attribute. The +`composite` op can be replaced with its decomposition without changing program +semantics. In cases where inlining the decomposition does not provide the same +op semantics, prefer using `custom_call`. 
+ +The `version` field (defaults to `0`) is used to denote when a composite\'s +semantics change. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite + +# Example +```mlir +%results = stablehlo.composite \"my.op\" %input0, %input1 { + composite_attributes = { + my_attribute = \"my_value\" + }, + decomposition = @my_op, + version = 1 : i32 +} : (tensor, tensor) -> tensor +``` +""" +function composite(inputs::Vector{Value}; result_0::Vector{IR.Type}, name, composite_attributes=nothing, decomposition, version=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("name", name), namedattribute("decomposition", decomposition), ] + !isnothing(composite_attributes) && push!(attributes, namedattribute("composite_attributes", composite_attributes)) + !isnothing(version) && push!(attributes, namedattribute("version", version)) + + create_operation( + "stablehlo.composite", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`concatenate` + +Concatenates a variadic number of tensors in `inputs` along `dimension` +dimension in the same order as the given arguments and produces a `result` +tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate + +# Example +```mlir +%result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> +``` +""" +function concatenate(inputs::Vector{Value}; result_0=nothing::Union{Nothing, IR.Type}, dimension, location=Location()) + op_ty_results = IR.Type[] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.concatenate", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`constant` + +Produces an `output` tensor from a constant `value`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant + +# Example +```mlir +%output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> +``` +""" +function constant(; output=nothing::Union{Nothing, IR.Type}, value, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("value", value), ] + !isnothing(output) && push!(op_ty_results, output) + + create_operation( + "stablehlo.constant", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`convert` + +Performs an element-wise conversion from one element type to another on +`operand` tensor and produces a `result` tensor. 
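+
+The destination type must be given explicitly; a sketch (`x` is an
+existing `Value` and `result_ty` a pre-built destination `IR.Type`):
+
+```julia
+converted = convert(x; result=result_ty)
+```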
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert + +# Example +```mlir +%result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex> +``` +""" +function convert(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.convert", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`convolution` + +Computes dot products between windows of `lhs` and slices of `rhs` and +produces `result`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution + +# Example +```mlir +%result = stablehlo.convolution(%lhs, %rhs) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [4, 4], + pad = [[0, 0], [0, 0]], + lhs_dilate = [2, 2], + rhs_dilate = [1, 1], + reverse = [0, 0] + } { + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : +(tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> +``` +""" +function convolution(lhs::Value, rhs::Value; result_0::IR.Type, window_strides=nothing, padding=nothing, lhs_dilation=nothing, rhs_dilation=nothing, window_reversal=nothing, dimension_numbers, feature_group_count, batch_group_count, precision_config=nothing, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension_numbers", dimension_numbers), namedattribute("feature_group_count", feature_group_count), namedattribute("batch_group_count", batch_group_count), ] + !isnothing(window_strides) && push!(attributes, namedattribute("window_strides", window_strides)) + !isnothing(padding) && push!(attributes, namedattribute("padding", padding)) + !isnothing(lhs_dilation) && push!(attributes, namedattribute("lhs_dilation", lhs_dilation)) + !isnothing(rhs_dilation) && push!(attributes, namedattribute("rhs_dilation", rhs_dilation)) + !isnothing(window_reversal) && push!(attributes, namedattribute("window_reversal", window_reversal)) + !isnothing(precision_config) && push!(attributes, namedattribute("precision_config", precision_config)) + + create_operation( + "stablehlo.convolution", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`cosine` + +Performs element-wise cosine operation on `operand` tensor and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine + +# Example +```mlir +%result = stablehlo.cosine %operand : tensor<2xf32> +``` +""" +function cosine(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.cosine", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`create_token` + +This operation is on its way out of StableHLO, so it is not included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. 
+ +Informally, this operation does the same thing as AfterAllOp with 0 inputs: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all + +# Example +```mlir +%output = stablehlo.create_token : !stablehlo.token +``` +""" +function create_token(; output=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(output) && push!(op_ty_results, output) + + create_operation( + "stablehlo.create_token", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`cross_replica_sum` + +This operation is on its way out of StableHLO, so it is not included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. + +Informally, this operation does the same thing as AllReduceOp with +`channel_id = 0`, `use_global_device_ids = false` and `computation` +implementing addition: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce + +# Example +```mlir +%result = \"stablehlo.cross-replica-sum\"(%operand) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> +} : (tensor<4xf32>) -> tensor<4xf32> +``` +""" +function cross_replica_sum(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, replica_groups, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.cross-replica-sum", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`custom_call` + +Encapsulates an implementation-defined operation `call_target_name` that +takes `inputs` and `called_computations` and produces `results`. + +Depending on the API version there are two ways to pass extra bits of static +information to the external function: +1. Use `API_VERSION_TYPED_FFI` which allows passing a dictionary attribute. +2. Use a previous API version with a StringAttr to encode backend config. 
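+
+A hedged Julia sketch of calling this wrapper (every name below is a
+placeholder for a `Value`, `IR.Type`, or attribute constructed elsewhere):
+
+```julia
+# result_0 lists the result types; call_target_name is the only required attribute.
+op = custom_call([in0]; result_0=[t], call_target_name=target, api_version=v)
+```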
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call + +# Example +```mlir +%results = stablehlo.custom_call @foo(%input0) { + backend_config = {bar = 42 : i32}, + api_version = 4 : i32, + called_computations = [@foo] +} : (tensor) -> tensor +``` +""" +function custom_call(inputs::Vector{Value}; result_0::Vector{IR.Type}, call_target_name, has_side_effect=nothing, backend_config=nothing, api_version=nothing, called_computations=nothing, operand_layouts=nothing, result_layouts=nothing, output_operand_aliases=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("call_target_name", call_target_name), ] + !isnothing(has_side_effect) && push!(attributes, namedattribute("has_side_effect", has_side_effect)) + !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(api_version) && push!(attributes, namedattribute("api_version", api_version)) + !isnothing(called_computations) && push!(attributes, namedattribute("called_computations", called_computations)) + !isnothing(operand_layouts) && push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + + create_operation( + "stablehlo.custom_call", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`divide` + +Performs element-wise division of dividend `lhs` and divisor `rhs` tensors +and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide + +# Example +```mlir +%result = stablehlo.divide %lhs, %rhs : tensor<4xf32> +``` +""" +function divide(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.divide", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`dot_general` + +Computes dot products between slices of `lhs` and slices of `rhs` and +produces a `result` tensor. 
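+
+A minimal Julia sketch (assuming `lhs` and `rhs` are existing `Value`s, `t`
+a result type, and `ddn` a dot-dimension-numbers attribute built elsewhere):
+
+```julia
+# precision_config and algorithm are optional keyword arguments.
+op = dot_general(lhs, rhs; result_0=t, dot_dimension_numbers=ddn)
+```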
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general
+
+# Example
+```mlir
+%result = stablehlo.dot_general %lhs, %rhs,
+ batching_dims = [0] x [0],
+ contracting_dims = [2] x [1],
+ precision = [DEFAULT, DEFAULT],
+ algorithm = <lhs_precision_type = tf32, rhs_precision_type = tf32, accumulation_type = f32, lhs_component_count = 1, rhs_component_count = 1, num_primitive_operations = 1, allow_imprecise_accumulation = false>
+ : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
+```
+"""
+function dot_general(lhs::Value, rhs::Value; result_0::IR.Type, dot_dimension_numbers, precision_config=nothing, algorithm=nothing, location=Location())
+ op_ty_results = IR.Type[result_0, ]
+ operands = Value[lhs, rhs, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("dot_dimension_numbers", dot_dimension_numbers), ]
+ !isnothing(precision_config) && push!(attributes, namedattribute("precision_config", precision_config))
+ !isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm))
+
+ create_operation(
+ "stablehlo.dot_general", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`dot`
+
+This operation is on its way out of StableHLO, so it is not included in
+the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
+
+Informally, this operation does the same thing as XLA\'s Dot:
+https://www.tensorflow.org/xla/operation_semantics#dot
+
+# Example
+```mlir
+%0 = stablehlo.dot %arg0, %arg1 : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32>
+```
+"""
+function dot(lhs::Value, rhs::Value; result_0::IR.Type, precision_config=nothing, location=Location())
+ op_ty_results = IR.Type[result_0, ]
+ operands = Value[lhs, rhs, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(precision_config) && push!(attributes, namedattribute("precision_config", precision_config))
+
+ create_operation(
+ "stablehlo.dot", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`dynamic_broadcast_in_dim`
+
+This operation is functionally identical to
+[broadcast_in_dim](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim)
+op, but the result shape is specified dynamically via `output_dimensions`.
+
+It also accepts optional attributes to express static knowledge about the
+expanding behavior of dimensions. If not specified, all dimensions are
+assumed to be possibly expanding. The sets of dimensions that are known to
+be expanding and the set of dimensions that are known to be non-expanding
+must be disjoint and they must be a subset of the operand\'s dimensions.
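+
+A minimal Julia sketch (placeholders throughout; `dims` is the `Value`
+carrying the runtime output shape):
+
+```julia
+op = dynamic_broadcast_in_dim(x, dims; result_0=t, broadcast_dimensions=bd)
+```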
+ +See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim + +# Example +```mlir +%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64> +%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64> +%result = \"stablehlo.dynamic_broadcast_in_dim\"(%operand, %output_dimensions) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array +} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64> +``` +""" +function dynamic_broadcast_in_dim(operand::Value, output_dimensions::Value; result_0::IR.Type, broadcast_dimensions, known_expanding_dimensions=nothing, known_nonexpanding_dimensions=nothing, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[operand, output_dimensions, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("broadcast_dimensions", broadcast_dimensions), ] + !isnothing(known_expanding_dimensions) && push!(attributes, namedattribute("known_expanding_dimensions", known_expanding_dimensions)) + !isnothing(known_nonexpanding_dimensions) && push!(attributes, namedattribute("known_nonexpanding_dimensions", known_nonexpanding_dimensions)) + + create_operation( + "stablehlo.dynamic_broadcast_in_dim", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`dynamic_conv` + +This operation is functionally identical to +[convolution](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution) +op, but the padding is specified dynamically via `padding`. + +# Example +```mlir +%padding = stablehlo.constant dense<2> : tensor<2x2xi64> +%result = \"stablehlo.dynamic_conv\"(%lhs, %rhs, %padding) { + window_strides = array, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] +} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64> +``` +""" +function dynamic_conv(lhs::Value, rhs::Value, padding::Value; result_0::IR.Type, window_strides=nothing, lhs_dilation=nothing, rhs_dilation=nothing, window_reversal=nothing, dimension_numbers, feature_group_count, batch_group_count, precision_config=nothing, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[lhs, rhs, padding, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension_numbers", dimension_numbers), namedattribute("feature_group_count", feature_group_count), namedattribute("batch_group_count", batch_group_count), ] + !isnothing(window_strides) && push!(attributes, namedattribute("window_strides", window_strides)) + !isnothing(lhs_dilation) && push!(attributes, namedattribute("lhs_dilation", lhs_dilation)) + !isnothing(rhs_dilation) && push!(attributes, namedattribute("rhs_dilation", rhs_dilation)) + !isnothing(window_reversal) && push!(attributes, namedattribute("window_reversal", window_reversal)) + !isnothing(precision_config) && push!(attributes, namedattribute("precision_config", precision_config)) + + create_operation( + "stablehlo.dynamic_conv", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`dynamic_gather` + +This operation is functionally identical to 
+[gather](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather)
+op, with the `slice_sizes` specified dynamically as an operand.
+
+# Example
+```mlir
+%slice_sizes = stablehlo.constant dense<[1, 2, 2]> : tensor<3xi64>
+%result = \"stablehlo.dynamic_gather\"(%operand, %start_indices, %slice_sizes) {
+ dimension_numbers = #stablehlo.gather<
+ offset_dims = [2, 3],
+ collapsed_slice_dims = [0],
+ start_index_map = [0, 2],
+ index_vector_dim = 2>,
+ indices_are_sorted = false
+} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
+```
+"""
+function dynamic_gather(operand::Value, start_indices::Value, slice_sizes::Value; result_0=nothing::Union{Nothing, IR.Type}, dimension_numbers, indices_are_sorted=nothing, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, start_indices, slice_sizes, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("dimension_numbers", dimension_numbers), ]
+ !isnothing(result_0) && push!(op_ty_results, result_0)
+ !isnothing(indices_are_sorted) && push!(attributes, namedattribute("indices_are_sorted", indices_are_sorted))
+
+ create_operation(
+ "stablehlo.dynamic_gather", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`dynamic_iota`
+
+This operation is functionally identical to
+[iota](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota)
+op, but the result shape is specified dynamically via `output_shape`.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_iota
+
+# Example
+```mlir
+%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
+%0 = stablehlo.dynamic_iota %output_shape, dim = 0 : (tensor<2xi64>) -> tensor<4x5xi64>
+```
+"""
+function dynamic_iota(output_shape::Value; result::IR.Type, iota_dimension, location=Location())
+ op_ty_results = IR.Type[result, ]
+ operands = Value[output_shape, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("iota_dimension", iota_dimension), ]
+
+ create_operation(
+ "stablehlo.dynamic_iota", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`dynamic_pad`
+
+This operation is functionally identical to
+[pad](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad)
+op, but with `edge_padding_low`, `edge_padding_high` and `interior_padding`
+specified dynamically as values (see also
+https://github.com/openxla/stablehlo/pull/2306#discussion_r1595669709).
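+
+A minimal Julia sketch (placeholders throughout; note that the three
+padding arguments are operand `Value`s here, not static attributes):
+
+```julia
+op = dynamic_pad(x, pad_val, lo, hi, interior; result=t)
+```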
+ +See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_pad + +# Example +```mlir +%edge_padding_low = stablehlo.constant dense<[0, 1]> : tensor<2xi32> +%edge_padding_high = stablehlo.constant dense<[2, 1]> : tensor<2xi32> +%interior_padding = stablehlo.constant dense<[1, 2]> : tensor<2xi32> +%result = stablehlo.dynamic_pad %operand, %padding_value, + %edge_padding_low, %edge_padding_high, %interior_padding + : (tensor<2x3xi64>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64> +``` +""" +function dynamic_pad(operand::Value, padding_value::Value, edge_padding_low::Value, edge_padding_high::Value, interior_padding::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, padding_value, edge_padding_low, edge_padding_high, interior_padding, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.dynamic_pad", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`dynamic_reshape` + +This operation is functionally identical to +[reshape](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape) +op, but the result shape is specified dynamically via `output_shape`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_reshape + +# Example +```mlir +%output_shape = stablehlo.constant dense<[3, 2]> : tensor<2xi64> +%result = stablehlo.dynamic_reshape %operand, %output_shape : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64> +``` +""" +function dynamic_reshape(operand::Value, output_shape::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, output_shape, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.dynamic_reshape", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`dynamic_slice` + +Extracts a slice from the `operand` using dynamically-computed starting +indices and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice + +# Example +```mlir +%result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2] + : (tensor<4x4xi32>, tensor, tensor) -> tensor<2x2xi32> +``` +""" +function dynamic_slice(operand::Value, start_indices::Vector{Value}; result=nothing::Union{Nothing, IR.Type}, slice_sizes, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, start_indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("slice_sizes", slice_sizes), ] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.dynamic_slice", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`dynamic_update_slice` + +Produces a `result` tensor which is equal to the `operand` tensor except +that the slice starting at `start_indices` is updated with the values in +`update`. 
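+
+A minimal Julia sketch (placeholders throughout; when `result` is omitted
+the wrapper asks MLIR to infer the result type):
+
+```julia
+op = dynamic_update_slice(x, upd, [i0, i1])
+```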
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_update_slice + +# Example +```mlir +%result = stablehlo.dynamic_update_slice %operand, %update, %start_indices0, %start_indices1 + : (tensor<4x4xi32>, tensor<2x2xi32>, tensor, tensor) -> tensor<4x4xi32> +``` +""" +function dynamic_update_slice(operand::Value, update::Value, start_indices::Vector{Value}; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, update, start_indices..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.dynamic_update_slice", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`einsum` + +This operation is on its way out of StableHLO, so it is not included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. + +Informally, this operation does the same thing as TF\'s einsum: +https://www.tensorflow.org/api_docs/python/tf/einsum + +# Example +```mlir +%result = \"stablehlo.einsum\"(%lhs, %rhs) { + einsum_config = \"ab,bc->ac\" +} : (tensor<4x16xf32>, tensor<16x4xf32>) -> tensor<4x4xf32> +``` +""" +function einsum(lhs::Value, rhs::Value; result_0::IR.Type, einsum_config, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("einsum_config", einsum_config), ] + + create_operation( + "stablehlo.einsum", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`exponential` + +Performs element-wise exponential operation on `operand` tensor and produces +a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential + +# Example +```mlir +%result = stablehlo.exponential %operand : tensor<2x2xf64> +``` +""" +function exponential(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.exponential", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`exponential_minus_one` + +Performs element-wise exponential minus one operation on `operand` tensor +and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one + +# Example +```mlir +%result = stablehlo.exponential_minus_one %operand : tensor<2xf64> +``` +""" +function exponential_minus_one(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.exponential_minus_one", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`fft` + +Performs the forward and inverse Fourier transforms for real and complex +inputs/outputs. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#fft + +# Example +```mlir +%result = stablehlo.fft %operand, type = FFT, length = [4] : (tensor<4xcomplex>) -> tensor<4xcomplex> +``` +""" +function fft(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, fft_type, fft_length, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fft_type", fft_type), namedattribute("fft_length", fft_length), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.fft", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`floor` + +Performs element-wise floor of `operand` tensor and produces a `result` +tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor + +# Example +```mlir +%result = stablehlo.floor %operand : tensor<2xf32> +``` +""" +function floor(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.floor", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`gather` + +Gathers slices from `operand` tensor from offsets specified in +`start_indices` and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather + +# Example +```mlir +%result = \"stablehlo.gather\"(%operand, %start_indices) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3, 4], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [2, 1], + index_vector_dim = 3>, + slice_sizes = array, + indices_are_sorted = false +} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64> +``` +""" +function gather(operand::Value, start_indices::Value; result=nothing::Union{Nothing, IR.Type}, dimension_numbers, slice_sizes, indices_are_sorted=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, start_indices, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension_numbers", dimension_numbers), namedattribute("slice_sizes", slice_sizes), ] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(indices_are_sorted) && push!(attributes, namedattribute("indices_are_sorted", indices_are_sorted)) + + create_operation( + "stablehlo.gather", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`get_dimension_size` + +Produces the size of the given `dimension` of the `operand`. 
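+
+A minimal Julia sketch (`x` and `d` are placeholders for a `Value` and a
+dimension attribute; the result type is inferred when omitted):
+
+```julia
+op = get_dimension_size(x; dimension=d)
+```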
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_dimension_size
+
+# Example
+```mlir
+%result = stablehlo.get_dimension_size %operand, dim = 1 : (tensor<2x3xi64>) -> tensor<i32>
+```
+"""
+function get_dimension_size(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, dimension, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("dimension", dimension), ]
+ !isnothing(result_0) && push!(op_ty_results, result_0)
+
+ create_operation(
+ "stablehlo.get_dimension_size", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`get_tuple_element`
+
+Extracts element at `index` position of the `operand` tuple and produces a
+`result`.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_tuple_element
+
+# Example
+```mlir
+%result = stablehlo.get_tuple_element %operand[0] : (tuple<tensor<2xf64>, tuple<tensor<i64>>>) -> tensor<2xf64>
+```
+"""
+function get_tuple_element(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, index, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("index", index), ]
+ !isnothing(result_0) && push!(op_ty_results, result_0)
+
+ create_operation(
+ "stablehlo.get_tuple_element", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`if_`
+
+Produces the output from executing exactly one branch from `true_branch` or
+`false_branch` depending on the value of `pred`.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if
+
+# Example
+```mlir
+%result = \"stablehlo.if\"(%pred) ({
+  \"stablehlo.return\"(%result_true_branch) : (tensor<i32>) -> ()
+}, {
+  \"stablehlo.return\"(%result_false_branch) : (tensor<i32>) -> ()
+}) : (tensor<i1>) -> tensor<i32>
+```
+"""
+function if_(pred::Value; result_0::Vector{IR.Type}, true_branch::Region, false_branch::Region, location=Location())
+ op_ty_results = IR.Type[result_0..., ]
+ operands = Value[pred, ]
+ owned_regions = Region[true_branch, false_branch, ]
+ successors = Block[]
+ attributes = NamedAttribute[]
+
+ create_operation(
+ "stablehlo.if", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`imag`
+
+Extracts the imaginary part, element-wise, from the `operand` and produces a
+`result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag
+
+# Example
+```mlir
+%result = stablehlo.imag %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+```
+"""
+function imag(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[]
+ !isnothing(result) && push!(op_ty_results, result)
+
+ create_operation(
+ "stablehlo.imag", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`infeed`
+
+Reads data from the infeed and produces `results`.
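+
+A minimal Julia sketch (placeholders throughout; `result_0` must list all
+produced types, and as in the example below the last result is a token):
+
+```julia
+op = infeed(tok; result_0=[t, token_t])
+```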
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#infeed + +# Example +```mlir +%results0:2 = \"stablehlo.infeed\"(%token) : + (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token) +``` +""" +function infeed(token::Value; result_0::Vector{IR.Type}, infeed_config=nothing, layout=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[token, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(infeed_config) && push!(attributes, namedattribute("infeed_config", infeed_config)) + !isnothing(layout) && push!(attributes, namedattribute("layout", layout)) + + create_operation( + "stablehlo.infeed", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`iota` + +Fills an `output` tensor with values in increasing order starting from zero +along the `iota_dimension` dimension. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota + +# Example +```mlir +%output = stablehlo.iota dim = 0 : tensor<4x5xi32> +``` +""" +function iota(; output::IR.Type, iota_dimension, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("iota_dimension", iota_dimension), ] + + create_operation( + "stablehlo.iota", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`is_finite` + +Performs element-wise check whether the value in `x` is finite (i.e. is +neither +Inf, -Inf, nor NaN) and produces a `y` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite + +# Example +```mlir +%y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1> +``` +""" +function is_finite(x::Value; y=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[x, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(y) && push!(op_ty_results, y) + + create_operation( + "stablehlo.is_finite", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`log_plus_one` + +Performs element-wise logarithm plus one operation on `operand` tensor and +produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one + +# Example +```mlir +%result = stablehlo.log_plus_one %operand : tensor<5xf64> +``` +""" +function log_plus_one(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.log_plus_one", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`log` + +Performs element-wise logarithm operation on `operand` tensor and produces a +`result` tensor. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log + +# Example +```mlir +%result = stablehlo.log %operand : tensor<2x2xf64> +``` +""" +function log(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.log", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`logistic` + +Performs element-wise logistic operation on `operand` tensor and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic + +# Example +```mlir +%result = stablehlo.logistic %operand : tensor<2x2xf64> +``` +""" +function logistic(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.logistic", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`map` + +Applies a map function `computation` to `inputs` along the `dimensions` and +produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map + +# Example +```mlir +%result = \"stablehlo.map\"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.multiply %arg0, %arg1 : tensor + stablehlo.return %0 : tensor +}) { + dimensions = array +} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> +``` +""" +function map(inputs::Vector{Value}; result_0::IR.Type, dimensions, computation::Region, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[inputs..., ] + owned_regions = Region[computation, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimensions", dimensions), ] + + create_operation( + "stablehlo.map", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`maximum` + +Performs element-wise max operation on tensors `lhs` and `rhs` and produces +a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum + +# Example +```mlir +%result = stablehlo.maximum %lhs, %rhs : tensor<4xf32> +``` +""" +function maximum(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.maximum", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`minimum` + +Performs element-wise min operation on tensors `lhs` and `rhs` and produces a +`result` tensor. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum + +# Example +```mlir +%result = stablehlo.minimum %lhs, %rhs : tensor<4xf32> +``` +""" +function minimum(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.minimum", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`multiply` + +Performs element-wise product of two tensors `lhs` and `rhs` and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#multiply + +# Example +```mlir +%result = stablehlo.multiply %lhs, %rhs : tensor<2xi32> +``` +""" +function multiply(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.multiply", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`negate` + +Performs element-wise negation of `operand` tensor and produces a `result` +tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate + +# Example +```mlir +%result = stablehlo.negate %operand : tensor<2x3xi32> +``` +""" +function negate(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.negate", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`not` + +Performs element-wise NOT of tensor `operand` of type integer and produces +a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not + +# Example +```mlir +%result = stablehlo.not %operand : tensor<5x3x1xi1> +``` +""" +function not(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.not", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`optimization_barrier` + +Ensures that the operations that produce the `operand` are executed before any +operations that depend on the `result` and prevents compiler transformations +from moving operations across the barrier. Other than that, the operation is +an identity, i.e. `result` = `operand`. 
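+
+A minimal Julia sketch (placeholder `Value`s; with `result` omitted the
+result types are inferred, one per operand):
+
+```julia
+op = optimization_barrier([a, b])
+```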
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier + +# Example +```mlir +%result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor, tensor +``` +""" +function optimization_barrier(operand::Vector{Value}; result=nothing::Union{Nothing, Vector{IR.Type}}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result...) + + create_operation( + "stablehlo.optimization_barrier", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`or` + +Performs element-wise OR of two tensors `lhs` and `rhs` and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#or + +# Example +```mlir +%result = stablehlo.or %lhs, %rhs : tensor<2xi1> +``` +""" +function or(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.or", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`outfeed` + +Writes `inputs` to the outfeed and produces a `result` token. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#outfeed + +# Example +```mlir +%result = \"stablehlo.outfeed\"(%input0, %token) : + (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token +``` +""" +function outfeed(inputs::Vector{Value}, token::Value; result_0=nothing::Union{Nothing, IR.Type}, outfeed_config=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[inputs..., token, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(outfeed_config) && push!(attributes, namedattribute("outfeed_config", outfeed_config)) + + create_operation( + "stablehlo.outfeed", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`pad` + +Expands `operand` by padding around the tensor as well as between the +elements of the tensor with the given `padding_value`. 
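+
+A minimal Julia sketch (placeholders throughout; unlike `dynamic_pad`, the
+three padding arguments here are static attributes):
+
+```julia
+op = pad(x, pv; edge_padding_low=lo, edge_padding_high=hi, interior_padding=ip)
+```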
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad + +# Example +```mlir +%0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [1, 2] + : (tensor<2x3xi32>, tensor) -> tensor<5x9xi32> +``` +""" +function pad(operand::Value, padding_value::Value; result_0=nothing::Union{Nothing, IR.Type}, edge_padding_low, edge_padding_high, interior_padding, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, padding_value, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("edge_padding_low", edge_padding_low), namedattribute("edge_padding_high", edge_padding_high), namedattribute("interior_padding", interior_padding), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.pad", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`partition_id` + +Produces `partition_id` of the current process. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#partition_id + +# Example +```mlir +%result = stablehlo.partition_id : tensor +``` +""" +function partition_id(; result_0=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.partition_id", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`popcnt` + +Performs element-wise count of the number of bits set in the `operand` +tensor and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt + +# Example +```mlir +%result = stablehlo.popcnt %operand : tensor<4xi64> +``` +""" +function popcnt(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.popcnt", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`power` + +Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and +produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power + +# Example +```mlir +%result = stablehlo.power %lhs, %rhs : tensor<6xf64> +``` +""" +function power(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.power", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`real_dynamic_slice` + +This operation is a work in progress, so it is not yet included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. + +Informally, this operation does the same thing as SliceOp except +that `start_indices`, `limit_indices` and `strides` are specified dynamically: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice + +# Example +```mlir +%result = stablehlo.real_dynamic_slice %operand, + %start_indices, %limit_indices, %strides + : (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32> +``` +""" +function real_dynamic_slice(operand::Value, start_indices::Value, limit_indices::Value, strides::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, start_indices, limit_indices, strides, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.real_dynamic_slice", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`real` + +Extracts the real part, element-wise, from the `operand` and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real + +# Example +```mlir +%result = stablehlo.real %operand : (tensor<2xcomplex>) -> tensor<2xf32> +``` +""" +function real(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.real", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`recv` + +Receives data from a channel with `channel_id` and produces `results`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#recv + +# Example +```mlir +%results:2 = \"stablehlo.recv\"(%token) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true +} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token) +``` +""" +function recv(token::Value; result_0::Vector{IR.Type}, channel_handle, is_host_transfer=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[token, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("channel_handle", channel_handle), ] + !isnothing(is_host_transfer) && push!(attributes, namedattribute("is_host_transfer", is_host_transfer)) + + create_operation( + "stablehlo.recv", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`reduce` + +Applies a reduction function `body` to `inputs` and `init_values` along the +`dimensions` and produces a `result` tensor. 
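+
+A minimal Julia sketch (placeholders throughout; `body_region` is assumed
+to be a `Region` holding the reduction block, built separately):
+
+```julia
+op = reduce([x], [init]; result_0=[t], dimensions=dims, body=body_region)
+```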
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce
+
+# Example
+```mlir
+%result = \"stablehlo.reduce\"(%input, %init_value) ({
+ ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+ %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
+ stablehlo.return %0 : tensor<i64>
+}) {
+ dimensions = array<i64: 1>
+} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
+```
+"""
+function reduce(inputs::Vector{Value}, init_values::Vector{Value}; result_0::Vector{IR.Type}, dimensions, body::Region, location=Location())
+ op_ty_results = IR.Type[result_0..., ]
+ operands = Value[inputs..., init_values..., ]
+ owned_regions = Region[body, ]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("dimensions", dimensions), ]
+
+ create_operation(
+ "stablehlo.reduce", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`reduce_precision`
+
+Performs element-wise conversion of `operand` to another floating-point type
+that uses `exponent_bits` and `mantissa_bits` and back to the original
+floating-point type and produces an `output` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision
+
+# Example
+```mlir
+%output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64>
+```
+"""
+function reduce_precision(operand::Value; output=nothing::Union{Nothing, IR.Type}, exponent_bits, mantissa_bits, location=Location())
+ op_ty_results = IR.Type[]
+ operands = Value[operand, ]
+ owned_regions = Region[]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("exponent_bits", exponent_bits), namedattribute("mantissa_bits", mantissa_bits), ]
+ !isnothing(output) && push!(op_ty_results, output)
+
+ create_operation(
+ "stablehlo.reduce_precision", location;
+ operands, owned_regions, successors, attributes,
+ results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+ result_inference=(length(op_ty_results) == 0 ? true : false)
+ )
+end
+
+"""
+`reduce_scatter`
+
+Within each process group in the process grid, performs reduction, using
+`computations`, over the values of the `operand` tensor from each process,
+splits the reduction result along `scatter_dimension` into parts, and
+scatters the split parts between the processes to produce the `result`.
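+
+A minimal Julia sketch (placeholders throughout; `comp` is an assumed
+`Region` containing the reduction computation):
+
+```julia
+op = reduce_scatter(x; result_0=t, scatter_dimension=sd, replica_groups=rg, computation=comp)
+```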
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_scatter
+
+# Example
+```mlir
+%result = \"stablehlo.reduce_scatter\"(%operand) ({
+ ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+ %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
+ stablehlo.return %0 : tensor<i64>
+}) {
+ scatter_dimension = 1 : i64,
+ replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+ channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
+} : (tensor<2x4xi64>) -> tensor<2x2xi64>
+```
+"""
+function reduce_scatter(operand::Value; result_0::IR.Type, scatter_dimension, replica_groups, channel_handle=nothing, use_global_device_ids=nothing, computation::Region, location=Location())
+ op_ty_results = IR.Type[result_0, ]
+ operands = Value[operand, ]
+ owned_regions = Region[computation, ]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("scatter_dimension", scatter_dimension), namedattribute("replica_groups", replica_groups), ]
+ !isnothing(channel_handle) && push!(attributes, namedattribute("channel_handle", channel_handle))
+ !isnothing(use_global_device_ids) && push!(attributes, namedattribute("use_global_device_ids", use_global_device_ids))
+
+ create_operation(
+ "stablehlo.reduce_scatter", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`reduce_window`
+
+Applies a reduction function `body` to windows of `inputs` and `init_values`
+and produces `results`.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
+
+# Example
+```mlir
+%result = \"stablehlo.reduce_window\"(%input, %init_value) ({
+ ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+ %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
+ stablehlo.return %0 : tensor<i64>
+}) {
+ window_dimensions = array<i64: 2, 1>,
+ window_strides = array<i64: 4, 1>,
+ base_dilations = array<i64: 2, 1>,
+ window_dilations = array<i64: 3, 1>,
+ padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
+} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
+```
+"""
+function reduce_window(inputs::Vector{Value}, init_values::Vector{Value}; result_0::Vector{IR.Type}, window_dimensions, window_strides=nothing, base_dilations=nothing, window_dilations=nothing, padding=nothing, body::Region, location=Location())
+ op_ty_results = IR.Type[result_0..., ]
+ operands = Value[inputs..., init_values..., ]
+ owned_regions = Region[body, ]
+ successors = Block[]
+ attributes = NamedAttribute[namedattribute("window_dimensions", window_dimensions), ]
+ !isnothing(window_strides) && push!(attributes, namedattribute("window_strides", window_strides))
+ !isnothing(base_dilations) && push!(attributes, namedattribute("base_dilations", base_dilations))
+ !isnothing(window_dilations) && push!(attributes, namedattribute("window_dilations", window_dilations))
+ !isnothing(padding) && push!(attributes, namedattribute("padding", padding))
+
+ create_operation(
+ "stablehlo.reduce_window", location;
+ operands, owned_regions, successors, attributes,
+ results=op_ty_results,
+ result_inference=false
+ )
+end
+
+"""
+`remainder`
+
+Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
+and produces a `result` tensor.
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder + +# Example +```mlir +%result = stablehlo.remainder %lhs, %rhs : tensor<4xi64> +``` +""" +function remainder(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.remainder", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`replica_id` + +Produces `replica_id` of the current process. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#replica_id + +# Example +```mlir +%result = stablehlo.replica_id : tensor +``` +""" +function replica_id(; result_0=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.replica_id", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`reshape` + +Performs reshape of `operand` tensor to a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape + +# Example +```mlir +%result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32> +``` +""" +function reshape(operand::Value; result_0::IR.Type, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.reshape", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function return_(results::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[results..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.return", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`reverse` + +Reverses the order of elements in the `operand` along the specified +`dimensions` and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reverse + +# Example +```mlir +%result = stablehlo.reverse %operand, dims = [1] : tensor<3x2xi32> +``` +""" +function reverse(operand::Value; result=nothing::Union{Nothing, IR.Type}, dimensions, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimensions", dimensions), ] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.reverse", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +""" +`rng_bit_generator` + +Returns an `output` filled with uniform random data and an updated output +state `output_state` given an initial state `initial_state` using the +pseudorandom number generator algorithm `rng_algorithm`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator + +# Example +```mlir +%output_state, %output = stablehlo.rng_bit_generator %initial_state, algorithm = THREE_FRY : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>) +``` +""" +function rng_bit_generator(initial_state::Value; output_state::IR.Type, output::IR.Type, rng_algorithm, location=Location()) + op_ty_results = IR.Type[output_state, output, ] + operands = Value[initial_state, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("rng_algorithm", rng_algorithm), ] + + create_operation( + "stablehlo.rng_bit_generator", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`rng` + +Generates random numbers using the `rng_distribution` algorithm and produces +a `result` tensor of a given shape `shape`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng + +# Example +```mlir +%result = stablehlo.rng %a, %b, %shape, distribution = NORMAL : (tensor, tensor, tensor<2xi64>) -> tensor<3x3xi32> +``` +""" +function rng(a::Value, b::Value, shape::Value; result=nothing::Union{Nothing, IR.Type}, rng_distribution, location=Location()) + op_ty_results = IR.Type[] + operands = Value[a, b, shape, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("rng_distribution", rng_distribution), ] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.rng", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`round_nearest_even` + +Performs element-wise rounding towards the nearest integer, breaking ties +towards the even integer, on the `operand` tensor and produces a `result` +tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even + +# Example +```mlir +%result = stablehlo.round_nearest_even %operand : tensor<5xf64> +``` +""" +function round_nearest_even(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.round_nearest_even", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`round_nearest_afz` + +Performs element-wise rounding towards the nearest integer, breaking ties +away from zero, on the `operand` tensor and produces a `result` tensor. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz + +# Example +```mlir +%result = stablehlo.round_nearest_afz %operand : tensor<5xf64> +``` +""" +function round_nearest_afz(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.round_nearest_afz", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`rsqrt` + +Performs element-wise reciprocal square root operation on `operand` tensor +and produces a `result` tensor, implementing the `rSqrt` operation from the +IEEE-754 specification. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt + +# Example +```mlir +%result = stablehlo.rsqrt %operand : tensor<2x2xf32> +``` +""" +function rsqrt(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.rsqrt", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`scatter` + +Produces `results` tensors which are equal to `inputs` tensors except that +several slices specified by `scatter_indices` are updated with the values +`updates` using `update_computation`. 
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
+
+# Example
+```mlir
+%result = \"stablehlo.scatter\"(%input, %scatter_indices, %update) ({
+  ^bb0(%arg0: tensor, %arg1: tensor):
+    %0 = stablehlo.add %arg0, %arg1 : tensor
+    stablehlo.return %0 : tensor
+}) {
+  scatter_dimension_numbers = #stablehlo.scatter<
+    update_window_dims = [3, 4],
+    inserted_window_dims = [1],
+    input_batching_dims = [0],
+    scatter_indices_batching_dims = [1],
+    scatter_dims_to_operand_dims = [2, 1],
+    index_vector_dim = 3>,
+  indices_are_sorted = false,
+  unique_indices = false
+} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+```
+"""
+function scatter(inputs::Vector{Value}, scatter_indices::Value, updates::Vector{Value}; result_0::Vector{IR.Type}, scatter_dimension_numbers, indices_are_sorted=nothing, unique_indices=nothing, update_computation::Region, location=Location())
+    op_ty_results = IR.Type[result_0..., ]
+    operands = Value[inputs..., scatter_indices, updates..., ]
+    owned_regions = Region[update_computation, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("scatter_dimension_numbers", scatter_dimension_numbers), ]
+    !isnothing(indices_are_sorted) && push!(attributes, namedattribute("indices_are_sorted", indices_are_sorted))
+    !isnothing(unique_indices) && push!(attributes, namedattribute("unique_indices", unique_indices))
+
+    create_operation(
+        "stablehlo.scatter", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`select_and_scatter`
+
+Scatters the values from the `source` tensor using `scatter` based on the
+outcome of `reduce_window` of the `input` tensor using `select` and produces
+a `result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select_and_scatter
+
+# Example
+```mlir
+%result = \"stablehlo.select_and_scatter\"(%operand, %source, %init_value) ({
+  ^bb0(%arg0: tensor, %arg1: tensor):
+    %0 = stablehlo.compare GE, %arg0, %arg1 : (tensor, tensor) -> tensor
+    stablehlo.return %0 : tensor
+}, {
+  ^bb0(%arg0: tensor, %arg1: tensor):
+    %0 = stablehlo.add %arg0, %arg1 : tensor
+    stablehlo.return %0 : tensor
+}) {
+  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
+  window_strides = dense<[2, 1]> : tensor<2xi64>,
+  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
+} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor) -> tensor<4x2xi64>
+```
+"""
+function select_and_scatter(operand::Value, source::Value, init_value::Value; result_0::IR.Type, window_dimensions=nothing, window_strides=nothing, padding=nothing, select::Region, scatter::Region, location=Location())
+    op_ty_results = IR.Type[result_0, ]
+    operands = Value[operand, source, init_value, ]
+    owned_regions = Region[select, scatter, ]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(window_dimensions) && push!(attributes, namedattribute("window_dimensions", window_dimensions))
+    !isnothing(window_strides) && push!(attributes, namedattribute("window_strides", window_strides))
+    !isnothing(padding) && push!(attributes, namedattribute("padding", padding))
+
+    create_operation(
+        "stablehlo.select_and_scatter", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`select`
+
+Produces a `result` tensor where each element is selected from `on_true` or
+`on_false` tensor based on the value of the corresponding element of `pred`.
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select + +# Example +```mlir +%result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32> +``` +""" +function select(pred::Value, on_true::Value, on_false::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[pred, on_true, on_false, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.select", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`send` + +Sends `inputs` to a channel `channel_id` and produces a `result` token. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#send + +# Example +```mlir +%result = \"stablehlo.send\"(%operand, %token) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true +} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token +``` +""" +function send(inputs::Vector{Value}, token::Value; result_0=nothing::Union{Nothing, IR.Type}, channel_handle, is_host_transfer=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[inputs..., token, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("channel_handle", channel_handle), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + !isnothing(is_host_transfer) && push!(attributes, namedattribute("is_host_transfer", is_host_transfer)) + + create_operation( + "stablehlo.send", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`set_dimension_size` + +This operation is a work in progress, so it is not yet included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. + +Informally, this operation does the same thing as XLA\'s SetDimensionSize: +https://www.tensorflow.org/xla/operation_semantics#setdimensionsize + +# Example +```mlir +%0 = stablehlo.set_dimension_size %arg0, %arg1, dim = 1 : (tensor<4x2xf32>, tensor) -> tensor<4x2xf32> +``` +""" +function set_dimension_size(operand::Value, size::Value; result_0=nothing::Union{Nothing, IR.Type}, dimension, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, size, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.set_dimension_size", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`shift_left` + +Performs element-wise left-shift operation on the `lhs` tensor by `rhs` +number of bits and produces a `result` tensor. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left + +# Example +```mlir +%result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64> +``` +""" +function shift_left(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.shift_left", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`shift_right_arithmetic` + +Performs element-wise arithmetic right-shift operation on the `lhs` tensor +by `rhs` number of bits and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmetic + +# Example +```mlir +%result = stablehlo.shift_right_arithmetic %lhs, %rhs : tensor<3xi64> +``` +""" +function shift_right_arithmetic(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.shift_right_arithmetic", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`shift_right_logical` + +Performs element-wise logical right-shift operation on the `lhs` tensor by +`rhs` number of bits and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical + +# Example +```mlir +%result = stablehlo.shift_right_logical %lhs, %rhs : tensor<3xi64> +``` +""" +function shift_right_logical(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.shift_right_logical", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`sign` + +Returns the sign of the `operand` element-wise and produces a `result` +tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign + +# Example +```mlir +%result = stablehlo.sign %operand : tensor<5xf64> +``` +""" +function sign(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.sign", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`sine` + +Performs element-wise sine operation on `operand` tensor and produces a +`result` tensor. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine + +# Example +```mlir +%result = stablehlo.sine %operand : tensor<2xf32> +``` +""" +function sine(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.sine", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`slice` + +Extracts a slice from the `operand` using statically-computed starting +indices and produces a `result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice + +# Example +```mlir +%result = stablehlo.slice %operand [1:3, 4:8:2] + : (tensor<3x8xi64>) -> tensor<2x2xi64> + +// Same in generic form: the `1:3` above is mapped to the first entry in +// `start_indices` and `limit_indices`, while `strides` is implicitly 1. +// The `4:8:2` above is parsed into the second entry of `start_indices`, +// `limit_indices` and `strides` respectively. +%result = \"stablehlo.slice\" (%operand) { + start_indices = array, + limit_indices = array, + strides = array +} : (tensor<3x8xi64>) -> tensor<2x2xi64> +``` +""" +function slice(operand::Value; result_0=nothing::Union{Nothing, IR.Type}, start_indices, limit_indices, strides, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("start_indices", start_indices), namedattribute("limit_indices", limit_indices), namedattribute("strides", strides), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.slice", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`sort` + +Sorts a variadic number of tensors in `inputs` together, according to a +custom `comparator`, along the given `dimension` and produces a variadic +number of tensors as `results`. 
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sort
+
+# Example
+```mlir
+%result0, %result1 = \"stablehlo.sort\"(%input0, %input1) ({
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor):
+    %predicate = stablehlo.compare GT, %arg0, %arg1 : (tensor, tensor) -> tensor
+    stablehlo.return %predicate : tensor
+}) {
+  dimension = 0 : i64,
+  is_stable = true
+} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
+```
+"""
+function sort(inputs::Vector{Value}; result_0::Vector{IR.Type}, dimension=nothing, is_stable=nothing, comparator::Region, location=Location())
+    op_ty_results = IR.Type[result_0..., ]
+    operands = Value[inputs..., ]
+    owned_regions = Region[comparator, ]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(dimension) && push!(attributes, namedattribute("dimension", dimension))
+    !isnothing(is_stable) && push!(attributes, namedattribute("is_stable", is_stable))
+
+    create_operation(
+        "stablehlo.sort", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`sqrt`
+
+Performs element-wise square root operation on `operand` tensor and produces
+a `result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt
+
+# Example
+```mlir
+%result = stablehlo.sqrt %operand : tensor<2x2xf32>
+```
+"""
+function sqrt(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "stablehlo.sqrt", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`subtract`
+
+Performs element-wise subtraction of two tensors `lhs` and `rhs` and
+produces a `result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#subtract
+
+# Example
+```mlir
+%result = stablehlo.subtract %lhs, %rhs : tensor<2xi32>
+```
+"""
+function subtract(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "stablehlo.subtract", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`tan`
+
+Performs element-wise tangent operation on `operand` tensor and
+produces a `result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan
+
+# Example
+```mlir
+%result = stablehlo.tan %operand : tensor<2x2xf64>
+```
+"""
+function tan(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "stablehlo.tan", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
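+
+# Usage sketch (hand-written illustration, not produced by the generator): the
+# wrappers above are ordinary Julia functions over MLIR `Value`s. Assuming `x`
+# and `y` are tensor-typed `Value`s defined inside an active block, one might
+# write:
+#
+#     d = subtract(x, y)         # result type left to MLIR's type inference
+#     t = tan(IR.result(d))      # feed one op's result into the next wrapper
+#
+# Passing `result=some_type` instead pins an explicit `IR.Type`; `IR.result`
+# and `some_type` are assumptions about the surrounding IR API, not part of
+# this generated file. Every call lowers to `create_operation` as shown above.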
+
+"""
+`tanh`
+
+Performs element-wise hyperbolic tangent operation on `operand` tensor and
+produces a `result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh
+
+# Example
+```mlir
+%result = stablehlo.tanh %operand : tensor<2xf32>
+```
+"""
+function tanh(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "stablehlo.tanh", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`torch_index_select`
+
+This operation is on its way out of StableHLO, so it is not included in
+the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
+
+Informally, this operation does the same thing as PyTorch\'s index_select,
+augmented with support for batch dimensions:
+https://pytorch.org/docs/stable/generated/torch.index_select.html.
+
+The `batch_dims` attribute specifies the number of major batch dimensions
+(0 or more) that act like a multidimensional loop over both the operand and
+the index.
+
+# Example
+```mlir
+%result = \"stablehlo.torch_index_select\"(%operand, %index) {
+  dim = 2 : i64,
+  batch_dims = 1 : i64
+} : (tensor<8x128x3072x64xf32>, tensor<8x16x1024xi32>) -> tensor<8x128x16x1024x64xf32>
+```
+"""
+function torch_index_select(operand::Value, index::Value; result_0::IR.Type, dim, batch_dims, location=Location())
+    op_ty_results = IR.Type[result_0, ]
+    operands = Value[operand, index, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dim", dim), namedattribute("batch_dims", batch_dims), ]
+
+    create_operation(
+        "stablehlo.torch_index_select", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+"""
+`transpose`
+
+Permutes the dimensions of `operand` tensor using `permutation` and produces
+a `result` tensor.
+
+See:
+https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
+
+# Example
+```mlir
+%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<1x2x3xi32>) -> tensor<3x2x1xi32>
+```
+"""
+function transpose(operand::Value; result=nothing::Union{Nothing, IR.Type}, permutation, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("permutation", permutation), ]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    create_operation(
+        "stablehlo.transpose", location;
+        operands, owned_regions, successors, attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false)
+    )
+end
+
+"""
+`triangular_solve`
+
+Solves batches of systems of linear equations with lower or upper triangular
+coefficient matrices.
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#triangular_solve + +# Example +```mlir +%result = \"stablehlo.triangular_solve\"(%a, %b) { + left_side = true, + lower = true, + unit_diagonal = false, + transpose_a = #stablehlo +} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> +``` +""" +function triangular_solve(a::Value, b::Value; result_0=nothing::Union{Nothing, IR.Type}, left_side, lower, unit_diagonal, transpose_a, location=Location()) + op_ty_results = IR.Type[] + operands = Value[a, b, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("left_side", left_side), namedattribute("lower", lower), namedattribute("unit_diagonal", unit_diagonal), namedattribute("transpose_a", transpose_a), ] + !isnothing(result_0) && push!(op_ty_results, result_0) + + create_operation( + "stablehlo.triangular_solve", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`tuple` + +Produces a `result` tuple from values `val`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tuple + +# Example +```mlir +%result = stablehlo.tuple %val0, %val1 : tuple, tuple>> +``` +""" +function tuple(val::Vector{Value}; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[val..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.tuple", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`unary_einsum` + +This operation is on its way out of StableHLO, so it is not included in +the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. + +Informally, this operation does the same thing as TF\'s einsum: +https://www.tensorflow.org/api_docs/python/tf/einsum + +# Example +```mlir +%result = \"stablehlo.unary_einsum\"(%operand) { + einsum_config = \"ab->a\" +} : (tensor<4x16xf32>) -> tensor<4xf32> +``` +""" +function unary_einsum(operand::Value; result_0::IR.Type, einsum_config, location=Location()) + op_ty_results = IR.Type[result_0, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("einsum_config", einsum_config), ] + + create_operation( + "stablehlo.unary_einsum", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`uniform_dequantize` + +Performs element-wise conversion of quantized tensor `operand` to a +floating-point tensor `result` according to the quantization parameters +defined by the `operand` type. 
+ +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize + +# Example +```mlir +%result = stablehlo.uniform_dequantize %operand : (tensor<2x!quant.uniform>) -> tensor<2xf32> +``` +""" +function uniform_dequantize(operand::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.uniform_dequantize", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false) + ) +end + +""" +`uniform_quantize` + +Performs element-wise conversion of floating-point tensor or quantized +tensor `operand` to a quantized tensor `result` according to the +quantization parameters defined by the `result` type. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize + +# Example +```mlir +%result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform> +``` +""" +function uniform_quantize(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.uniform_quantize", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`while_` + +Produces the output from executing `body` function 0 or more times while the +`cond` function outputs `true`. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while + +# Example +```mlir +%results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor, tensor +cond { + %cond = stablehlo.compare LT, %arg0, %ten : (tensor, tensor) -> tensor + stablehlo.return %cond : tensor +} do { + %new_sum = stablehlo.add %arg1, %one : tensor + %new_i = stablehlo.add %arg0, %one : tensor + stablehlo.return %new_i, %new_sum : tensor, tensor +} +``` +""" +function while_(operand::Vector{Value}; result_0::Vector{IR.Type}, cond::Region, body::Region, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[operand..., ] + owned_regions = Region[cond, body, ] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "stablehlo.while", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`xor` + +Performs element-wise XOR of two tensors `lhs` and `rhs` and produces a +`result` tensor. + +See: +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#xor + +# Example +```mlir +%result = stablehlo.xor %lhs, %rhs : tensor<2xi32> +``` +""" +function xor(lhs::Value, rhs::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + create_operation( + "stablehlo.xor", location; + operands, owned_regions, successors, attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? 
true : false) + ) +end + +end # stablehlo diff --git a/src/mlir/Dialects/VHLO.jl b/src/mlir/Dialects/VHLO.jl new file mode 100755 index 000000000..042abb9cd --- /dev/null +++ b/src/mlir/Dialects/VHLO.jl @@ -0,0 +1,2008 @@ +module vhlo +using ...IR +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + + + +function abs_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.abs_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function add_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.add_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function after_all_v1(inputs::Vector{Value}; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.after_all_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function all_gather_v1(operand::Value; result::IR.Type, all_gather_dim, replica_groups, channel_id, use_global_device_ids, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("all_gather_dim", all_gather_dim), namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), namedattribute("use_global_device_ids", use_global_device_ids), ] + + create_operation( + "vhlo.all_gather_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function all_gather_v2(operands::Vector{Value}; results::Vector{IR.Type}, all_gather_dim, replica_groups, channel_id, use_global_device_ids, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("all_gather_dim", all_gather_dim), namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), namedattribute("use_global_device_ids", use_global_device_ids), ] + + create_operation( + "vhlo.all_gather_v2", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function all_reduce_v1(operand::Value; result::IR.Type, replica_groups, channel_id, use_global_device_ids, computation::Region, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[computation, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), namedattribute("use_global_device_ids", use_global_device_ids), ] + + create_operation( + "vhlo.all_reduce_v1", location; + operands, owned_regions, successors, attributes, + 
results=op_ty_results, + result_inference=false + ) +end + + +function all_reduce_v2(operands::Vector{Value}; results::Vector{IR.Type}, replica_groups, channel_id, use_global_device_ids, computation::Region, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operands..., ] + owned_regions = Region[computation, ] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), namedattribute("use_global_device_ids", use_global_device_ids), ] + + create_operation( + "vhlo.all_reduce_v2", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function all_to_all_v1(operand::Value; result::IR.Type, split_dimension, concat_dimension, split_count, replica_groups, channel_id, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("split_dimension", split_dimension), namedattribute("concat_dimension", concat_dimension), namedattribute("split_count", split_count), namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), ] + + create_operation( + "vhlo.all_to_all_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function all_to_all_v2(operands::Vector{Value}; results::Vector{IR.Type}, split_dimension, concat_dimension, split_count, replica_groups, channel_id, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("split_dimension", split_dimension), namedattribute("concat_dimension", concat_dimension), namedattribute("split_count", split_count), namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), ] + + create_operation( + "vhlo.all_to_all_v2", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function and_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.and_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function atan2_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.atan2_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function batch_norm_grad_v1(operand::Value, scale::Value, mean::Value, variance::Value, grad_output::Value; grad_operand::IR.Type, grad_scale::IR.Type, grad_offset::IR.Type, epsilon, feature_index, location=Location()) + op_ty_results = IR.Type[grad_operand, grad_scale, grad_offset, ] + operands = Value[operand, scale, mean, variance, grad_output, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("epsilon", epsilon), namedattribute("feature_index", feature_index), ] + + create_operation( + "vhlo.batch_norm_grad_v1", location; + 
operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function batch_norm_inference_v1(operand::Value, scale::Value, offset::Value, mean::Value, variance::Value; result::IR.Type, epsilon, feature_index, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, scale, offset, mean, variance, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("epsilon", epsilon), namedattribute("feature_index", feature_index), ] + + create_operation( + "vhlo.batch_norm_inference_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function batch_norm_training_v1(operand::Value, scale::Value, offset::Value; output::IR.Type, batch_mean::IR.Type, batch_var::IR.Type, epsilon, feature_index, location=Location()) + op_ty_results = IR.Type[output, batch_mean, batch_var, ] + operands = Value[operand, scale, offset, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("epsilon", epsilon), namedattribute("feature_index", feature_index), ] + + create_operation( + "vhlo.batch_norm_training_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function bitcast_convert_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.bitcast_convert_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function broadcast_in_dim_v1(operand::Value; result::IR.Type, broadcast_dimensions, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("broadcast_dimensions", broadcast_dimensions), ] + + create_operation( + "vhlo.broadcast_in_dim_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function broadcast_v1(operand::Value; result::IR.Type, broadcast_sizes, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("broadcast_sizes", broadcast_sizes), ] + + create_operation( + "vhlo.broadcast_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function call_v1(operands::Vector{Value}; results::Vector{IR.Type}, callee, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operands..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("callee", callee), ] + + create_operation( + "vhlo.call_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function case_v1(index::Value; results::Vector{IR.Type}, branches::Vector{Region}, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[index, ] + owned_regions = Region[branches..., ] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.case_v1", location; + operands, owned_regions, successors, attributes, + 
results=op_ty_results, + result_inference=false + ) +end + + +function cbrt_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.cbrt_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function ceil_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.ceil_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function cholesky_v1(a::Value; result::IR.Type, lower, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[a, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("lower", lower), ] + + create_operation( + "vhlo.cholesky_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function clamp_v1(min::Value, operand::Value, max::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[min, operand, max, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.clamp_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function count_leading_zeros_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.count_leading_zeros_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function collective_broadcast_v1(operand::Value; result::IR.Type, replica_groups, channel_id, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), ] + + create_operation( + "vhlo.collective_broadcast_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function collective_permute_v1(operand::Value; result::IR.Type, source_target_pairs, channel_id, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("source_target_pairs", source_target_pairs), namedattribute("channel_id", channel_id), ] + + create_operation( + "vhlo.collective_permute_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function compare_v1(lhs::Value, rhs::Value; result::IR.Type, comparison_direction, compare_type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("comparison_direction", comparison_direction), namedattribute("compare_type", compare_type), ] + + 
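+    # Editorial note: like every generated VHLO wrapper, this call supplies the
+    # result type explicitly and sets `result_inference=false`; VHLO is
+    # StableHLO's versioned serialization dialect, so result types are never
+    # inferred. A hypothetical invocation (all names illustrative only):
+    #     compare_v1(lhs, rhs; result=t, comparison_direction=dir, compare_type=ct)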
create_operation( + "vhlo.compare_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function complex_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.complex_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function composite_v1(inputs::Vector{Value}; results::Vector{IR.Type}, name, composite_attributes, decomposition, version, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("name", name), namedattribute("composite_attributes", composite_attributes), namedattribute("decomposition", decomposition), namedattribute("version", version), ] + + create_operation( + "vhlo.composite_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function concatenate_v1(inputs::Vector{Value}; result::IR.Type, dimension, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dimension", dimension), ] + + create_operation( + "vhlo.concatenate_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function constant_v1(; output::IR.Type, value, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("value", value), ] + + create_operation( + "vhlo.constant_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function convert_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.convert_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function convolution_v1(lhs::Value, rhs::Value; result::IR.Type, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("window_strides", window_strides), namedattribute("padding", padding), namedattribute("lhs_dilation", lhs_dilation), namedattribute("rhs_dilation", rhs_dilation), namedattribute("window_reversal", window_reversal), namedattribute("input_batch_dimension", input_batch_dimension), namedattribute("input_feature_dimension", input_feature_dimension), namedattribute("input_spatial_dimensions", input_spatial_dimensions), namedattribute("kernel_input_feature_dimension", 
kernel_input_feature_dimension), namedattribute("kernel_output_feature_dimension", kernel_output_feature_dimension), namedattribute("kernel_spatial_dimensions", kernel_spatial_dimensions), namedattribute("output_batch_dimension", output_batch_dimension), namedattribute("output_feature_dimension", output_feature_dimension), namedattribute("output_spatial_dimensions", output_spatial_dimensions), namedattribute("feature_group_count", feature_group_count), namedattribute("batch_group_count", batch_group_count), namedattribute("precision_config", precision_config), ] + + create_operation( + "vhlo.convolution_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function cosine_v1(operand::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.cosine_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function create_token_v1(; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.create_token_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function cross_replica_sum_v1(operand::Value; result::IR.Type, replica_groups, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("replica_groups", replica_groups), ] + + create_operation( + "vhlo.cross-replica-sum_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function custom_call_v1(inputs::Vector{Value}; results::Vector{IR.Type}, call_target_name, has_side_effect, backend_config, api_version, called_computations, operand_layouts, result_layouts, output_operand_aliases, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("call_target_name", call_target_name), namedattribute("has_side_effect", has_side_effect), namedattribute("backend_config", backend_config), namedattribute("api_version", api_version), namedattribute("called_computations", called_computations), namedattribute("operand_layouts", operand_layouts), namedattribute("result_layouts", result_layouts), namedattribute("output_operand_aliases", output_operand_aliases), ] + + create_operation( + "vhlo.custom_call_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function divide_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + create_operation( + "vhlo.divide_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function dot_general_v1(lhs::Value, rhs::Value; result::IR.Type, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, 
precision_config, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("lhs_batching_dimensions", lhs_batching_dimensions), namedattribute("rhs_batching_dimensions", rhs_batching_dimensions), namedattribute("lhs_contracting_dimensions", lhs_contracting_dimensions), namedattribute("rhs_contracting_dimensions", rhs_contracting_dimensions), namedattribute("precision_config", precision_config), ] + + create_operation( + "vhlo.dot_general_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function dot_general_v2(lhs::Value, rhs::Value; result::IR.Type, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config, lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations, allow_imprecise_accumulation, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("lhs_batching_dimensions", lhs_batching_dimensions), namedattribute("rhs_batching_dimensions", rhs_batching_dimensions), namedattribute("lhs_contracting_dimensions", lhs_contracting_dimensions), namedattribute("rhs_contracting_dimensions", rhs_contracting_dimensions), namedattribute("precision_config", precision_config), namedattribute("lhs_precision_type", lhs_precision_type), namedattribute("rhs_precision_type", rhs_precision_type), namedattribute("accumulation_type", accumulation_type), namedattribute("lhs_component_count", lhs_component_count), namedattribute("rhs_component_count", rhs_component_count), namedattribute("num_primitive_operations", num_primitive_operations), namedattribute("allow_imprecise_accumulation", allow_imprecise_accumulation), ] + + create_operation( + "vhlo.dot_general_v2", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function dot_v1(lhs::Value, rhs::Value; result::IR.Type, precision_config, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[lhs, rhs, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("precision_config", precision_config), ] + + create_operation( + "vhlo.dot_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function dynamic_broadcast_in_dim_v1(operand::Value, output_dimensions::Value; result::IR.Type, broadcast_dimensions, known_expanding_dimensions, known_nonexpanding_dimensions, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[operand, output_dimensions, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("broadcast_dimensions", broadcast_dimensions), namedattribute("known_expanding_dimensions", known_expanding_dimensions), namedattribute("known_nonexpanding_dimensions", known_nonexpanding_dimensions), ] + + create_operation( + "vhlo.dynamic_broadcast_in_dim_v1", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + + +function dynamic_conv_v1(lhs::Value, rhs::Value, d_padding::Value; result::IR.Type, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, 
input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, d_padding, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("window_strides", window_strides), namedattribute("padding", padding), namedattribute("lhs_dilation", lhs_dilation), namedattribute("rhs_dilation", rhs_dilation), namedattribute("window_reversal", window_reversal), namedattribute("input_batch_dimension", input_batch_dimension), namedattribute("input_feature_dimension", input_feature_dimension), namedattribute("input_spatial_dimensions", input_spatial_dimensions), namedattribute("kernel_input_feature_dimension", kernel_input_feature_dimension), namedattribute("kernel_output_feature_dimension", kernel_output_feature_dimension), namedattribute("kernel_spatial_dimensions", kernel_spatial_dimensions), namedattribute("output_batch_dimension", output_batch_dimension), namedattribute("output_feature_dimension", output_feature_dimension), namedattribute("output_spatial_dimensions", output_spatial_dimensions), namedattribute("feature_group_count", feature_group_count), namedattribute("batch_group_count", batch_group_count), namedattribute("precision_config", precision_config), ]
+
+    create_operation(
+        "vhlo.dynamic_conv_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_conv_v2(lhs::Value, rhs::Value, padding::Value; result::IR.Type, window_strides, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, padding, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("window_strides", window_strides), namedattribute("lhs_dilation", lhs_dilation), namedattribute("rhs_dilation", rhs_dilation), namedattribute("window_reversal", window_reversal), namedattribute("input_batch_dimension", input_batch_dimension), namedattribute("input_feature_dimension", input_feature_dimension), namedattribute("input_spatial_dimensions", input_spatial_dimensions), namedattribute("kernel_input_feature_dimension", kernel_input_feature_dimension), namedattribute("kernel_output_feature_dimension", kernel_output_feature_dimension), namedattribute("kernel_spatial_dimensions", kernel_spatial_dimensions), namedattribute("output_batch_dimension", output_batch_dimension), namedattribute("output_feature_dimension", output_feature_dimension), namedattribute("output_spatial_dimensions", output_spatial_dimensions), namedattribute("feature_group_count", feature_group_count), namedattribute("batch_group_count", batch_group_count), namedattribute("precision_config", precision_config), ]
+
+    create_operation(
+        "vhlo.dynamic_conv_v2", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_gather_v1(operand::Value, start_indices::Value, slice_sizes::Value; result::IR.Type, offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim, indices_are_sorted, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, start_indices, slice_sizes, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("offset_dims", offset_dims), namedattribute("collapsed_slice_dims", collapsed_slice_dims), namedattribute("start_index_map", start_index_map), namedattribute("index_vector_dim", index_vector_dim), namedattribute("indices_are_sorted", indices_are_sorted), ]
+
+    create_operation(
+        "vhlo.dynamic_gather_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_gather_v2(operand::Value, start_indices::Value, slice_sizes::Value; result::IR.Type, offset_dims, collapsed_slice_dims, operand_batching_dims, start_indices_batching_dims, start_index_map, index_vector_dim, indices_are_sorted, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, start_indices, slice_sizes, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("offset_dims", offset_dims), namedattribute("collapsed_slice_dims", collapsed_slice_dims), namedattribute("operand_batching_dims", operand_batching_dims), namedattribute("start_indices_batching_dims", start_indices_batching_dims), namedattribute("start_index_map", start_index_map), namedattribute("index_vector_dim", index_vector_dim), namedattribute("indices_are_sorted", indices_are_sorted), ]
+
+    create_operation(
+        "vhlo.dynamic_gather_v2", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_iota_v1(output_shape::Value; result::IR.Type, iota_dimension, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[output_shape, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("iota_dimension", iota_dimension), ]
+
+    create_operation(
+        "vhlo.dynamic_iota_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_pad_v1(operand::Value, padding_value::Value, edge_padding_low::Value, edge_padding_high::Value, interior_padding::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, padding_value, edge_padding_low, edge_padding_high, interior_padding, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.dynamic_pad_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_reshape_v1(operand::Value, output_shape::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, output_shape, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.dynamic_reshape_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_slice_v1(operand::Value, start_indices::Vector{Value}; result::IR.Type, slice_sizes, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, start_indices..., ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("slice_sizes", slice_sizes), ]
+
+    create_operation(
+        "vhlo.dynamic_slice_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function dynamic_update_slice_v1(operand::Value, update::Value, start_indices::Vector{Value}; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, update, start_indices..., ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.dynamic_update_slice_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function einsum_v1(lhs::Value, rhs::Value; result::IR.Type, einsum_config, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("einsum_config", einsum_config), ]
+
+    create_operation(
+        "vhlo.einsum_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function exponential_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.exponential_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function exponential_minus_one_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.exponential_minus_one_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function fft_v1(operand::Value; result::IR.Type, fft_type, fft_length, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("fft_type", fft_type), namedattribute("fft_length", fft_length), ]
+
+    create_operation(
+        "vhlo.fft_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function floor_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.floor_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function func_v1(; sym_name, function_type, sym_visibility, arg_attrs, res_attrs, body::Region, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[]
+    owned_regions = Region[body, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("sym_name", sym_name), namedattribute("function_type", function_type), namedattribute("sym_visibility", sym_visibility), namedattribute("arg_attrs", arg_attrs), namedattribute("res_attrs", res_attrs), ]
+
+    create_operation(
+        "vhlo.func_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function gather_v1(operand::Value, start_indices::Value; result::IR.Type, offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim, slice_sizes, indices_are_sorted, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, start_indices, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("offset_dims", offset_dims), namedattribute("collapsed_slice_dims", collapsed_slice_dims), namedattribute("start_index_map", start_index_map), namedattribute("index_vector_dim", index_vector_dim), namedattribute("slice_sizes", slice_sizes), namedattribute("indices_are_sorted", indices_are_sorted), ]
+
+    create_operation(
+        "vhlo.gather_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function gather_v2(operand::Value, start_indices::Value; result::IR.Type, offset_dims, collapsed_slice_dims, operand_batching_dims, start_indices_batching_dims, start_index_map, index_vector_dim, slice_sizes, indices_are_sorted, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, start_indices, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("offset_dims", offset_dims), namedattribute("collapsed_slice_dims", collapsed_slice_dims), namedattribute("operand_batching_dims", operand_batching_dims), namedattribute("start_indices_batching_dims", start_indices_batching_dims), namedattribute("start_index_map", start_index_map), namedattribute("index_vector_dim", index_vector_dim), namedattribute("slice_sizes", slice_sizes), namedattribute("indices_are_sorted", indices_are_sorted), ]
+
+    create_operation(
+        "vhlo.gather_v2", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function get_dimension_size_v1(operand::Value; result::IR.Type, dimension, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dimension", dimension), ]
+
+    create_operation(
+        "vhlo.get_dimension_size_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function get_tuple_element_v1(operand::Value; result::IR.Type, index, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("index", index), ]
+
+    create_operation(
+        "vhlo.get_tuple_element_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function if_v1(pred::Value; results::Vector{IR.Type}, true_branch::Region, false_branch::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[pred, ]
+    owned_regions = Region[true_branch, false_branch, ]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.if_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function imag_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.imag_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function infeed_v1(token::Value; results::Vector{IR.Type}, infeed_config, layout, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[token, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("infeed_config", infeed_config), namedattribute("layout", layout), ]
+
+    create_operation(
+        "vhlo.infeed_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function iota_v1(; output::IR.Type, iota_dimension, location=Location())
+    op_ty_results = IR.Type[output, ]
+    operands = Value[]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("iota_dimension", iota_dimension), ]
+
+    create_operation(
+        "vhlo.iota_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function is_finite_v1(x::Value; y::IR.Type, location=Location())
+    op_ty_results = IR.Type[y, ]
+    operands = Value[x, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.is_finite_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function log_plus_one_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.log_plus_one_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function log_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.log_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function logistic_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.logistic_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function map_v1(inputs::Vector{Value}; result::IR.Type, dimensions, computation::Region, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[inputs..., ]
+    owned_regions = Region[computation, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dimensions", dimensions), ]
+
+    create_operation(
+        "vhlo.map_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function maximum_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.maximum_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function minimum_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.minimum_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function multiply_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.multiply_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function negate_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.negate_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function not_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.not_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function optimization_barrier_v1(operand::Vector{Value}; result::Vector{IR.Type}, location=Location())
+    op_ty_results = IR.Type[result..., ]
+    operands = Value[operand..., ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.optimization_barrier_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function or_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.or_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function outfeed_v1(inputs::Vector{Value}, token::Value; result::IR.Type, outfeed_config, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[inputs..., token, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("outfeed_config", outfeed_config), ]
+
+    create_operation(
+        "vhlo.outfeed_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function pad_v1(operand::Value, padding_value::Value; result::IR.Type, edge_padding_low, edge_padding_high, interior_padding, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, padding_value, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("edge_padding_low", edge_padding_low), namedattribute("edge_padding_high", edge_padding_high), namedattribute("interior_padding", interior_padding), ]
+
+    create_operation(
+        "vhlo.pad_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function partition_id_v1(; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.partition_id_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function popcnt_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.popcnt_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function power_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.power_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function real_dynamic_slice_v1(operand::Value, start_indices::Value, limit_indices::Value, strides::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, start_indices, limit_indices, strides, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.real_dynamic_slice_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function real_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.real_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function recv_v1(token::Value; results::Vector{IR.Type}, channel_id, channel_type, is_host_transfer, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[token, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("channel_id", channel_id), namedattribute("channel_type", channel_type), namedattribute("is_host_transfer", is_host_transfer), ]
+
+    create_operation(
+        "vhlo.recv_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function reduce_v1(inputs::Vector{Value}, init_values::Vector{Value}; results::Vector{IR.Type}, dimensions, body::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[inputs..., init_values..., ]
+    owned_regions = Region[body, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dimensions", dimensions), ]
+
+    create_operation(
+        "vhlo.reduce_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function reduce_precision_v1(operand::Value; output::IR.Type, exponent_bits, mantissa_bits, location=Location())
+    op_ty_results = IR.Type[output, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("exponent_bits", exponent_bits), namedattribute("mantissa_bits", mantissa_bits), ]
+
+    create_operation(
+        "vhlo.reduce_precision_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function reduce_scatter_v1(operand::Value; result::IR.Type, scatter_dimension, replica_groups, channel_id, use_global_device_ids, computation::Region, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[computation, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("scatter_dimension", scatter_dimension), namedattribute("replica_groups", replica_groups), namedattribute("channel_id", channel_id), namedattribute("use_global_device_ids", use_global_device_ids), ]
+
+    create_operation(
+        "vhlo.reduce_scatter_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function reduce_window_v1(inputs::Vector{Value}, init_values::Vector{Value}; results::Vector{IR.Type}, window_dimensions, window_strides, base_dilations, window_dilations, padding, body::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[inputs..., init_values..., ]
+    owned_regions = Region[body, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("window_dimensions", window_dimensions), namedattribute("window_strides", window_strides), namedattribute("base_dilations", base_dilations), namedattribute("window_dilations", window_dilations), namedattribute("padding", padding), ]
+
+    create_operation(
+        "vhlo.reduce_window_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function remainder_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.remainder_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function replica_id_v1(; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.replica_id_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function reshape_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.reshape_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function return_v1(results::Vector{Value}; location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[results..., ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.return_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function reverse_v1(operand::Value; result::IR.Type, dimensions, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dimensions", dimensions), ]
+
+    create_operation(
+        "vhlo.reverse_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function rng_bit_generator_v1(initial_state::Value; output_state::IR.Type, output::IR.Type, rng_algorithm, location=Location())
+    op_ty_results = IR.Type[output_state, output, ]
+    operands = Value[initial_state, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("rng_algorithm", rng_algorithm), ]
+
+    create_operation(
+        "vhlo.rng_bit_generator_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function rng_v1(a::Value, b::Value, shape::Value; result::IR.Type, rng_distribution, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[a, b, shape, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("rng_distribution", rng_distribution), ]
+
+    create_operation(
+        "vhlo.rng_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function round_nearest_even_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.round_nearest_even_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function round_nearest_afz_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.round_nearest_afz_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function rsqrt_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.rsqrt_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function scatter_v1(inputs::Vector{Value}, scatter_indices::Value, updates::Vector{Value}; results::Vector{IR.Type}, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim, indices_are_sorted, unique_indices, update_computation::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[inputs..., scatter_indices, updates..., ]
+    owned_regions = Region[update_computation, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("update_window_dims", update_window_dims), namedattribute("inserted_window_dims", inserted_window_dims), namedattribute("scatter_dims_to_operand_dims", scatter_dims_to_operand_dims), namedattribute("index_vector_dim", index_vector_dim), namedattribute("indices_are_sorted", indices_are_sorted), namedattribute("unique_indices", unique_indices), ]
+
+    create_operation(
+        "vhlo.scatter_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function scatter_v2(inputs::Vector{Value}, scatter_indices::Value, updates::Vector{Value}; results::Vector{IR.Type}, update_window_dims, inserted_window_dims, input_batching_dims, scatter_indices_batching_dims, scatter_dims_to_operand_dims, index_vector_dim, indices_are_sorted, unique_indices, update_computation::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[inputs..., scatter_indices, updates..., ]
+    owned_regions = Region[update_computation, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("update_window_dims", update_window_dims), namedattribute("inserted_window_dims", inserted_window_dims), namedattribute("input_batching_dims", input_batching_dims), namedattribute("scatter_indices_batching_dims", scatter_indices_batching_dims), namedattribute("scatter_dims_to_operand_dims", scatter_dims_to_operand_dims), namedattribute("index_vector_dim", index_vector_dim), namedattribute("indices_are_sorted", indices_are_sorted), namedattribute("unique_indices", unique_indices), ]
+
+    create_operation(
+        "vhlo.scatter_v2", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function select_and_scatter_v1(operand::Value, source::Value, init_value::Value; result::IR.Type, window_dimensions, window_strides, padding, select::Region, scatter::Region, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, source, init_value, ]
+    owned_regions = Region[select, scatter, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("window_dimensions", window_dimensions), namedattribute("window_strides", window_strides), namedattribute("padding", padding), ]
+
+    create_operation(
+        "vhlo.select_and_scatter_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function select_v1(pred::Value, on_true::Value, on_false::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[pred, on_true, on_false, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.select_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function send_v1(inputs::Vector{Value}, token::Value; result::IR.Type, channel_id, channel_type, is_host_transfer, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[inputs..., token, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("channel_id", channel_id), namedattribute("channel_type", channel_type), namedattribute("is_host_transfer", is_host_transfer), ]
+
+    create_operation(
+        "vhlo.send_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function set_dimension_size_v1(operand::Value, size::Value; result::IR.Type, dimension, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, size, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dimension", dimension), ]
+
+    create_operation(
+        "vhlo.set_dimension_size_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function shift_left_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.shift_left_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function shift_right_arithmetic_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.shift_right_arithmetic_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function shift_right_logical_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.shift_right_logical_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function sign_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.sign_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function sine_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.sine_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function slice_v1(operand::Value; result::IR.Type, start_indices, limit_indices, strides, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("start_indices", start_indices), namedattribute("limit_indices", limit_indices), namedattribute("strides", strides), ]
+
+    create_operation(
+        "vhlo.slice_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function sort_v1(inputs::Vector{Value}; results::Vector{IR.Type}, dimension, is_stable, comparator::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[inputs..., ]
+    owned_regions = Region[comparator, ]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dimension", dimension), namedattribute("is_stable", is_stable), ]
+
+    create_operation(
+        "vhlo.sort_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function sqrt_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.sqrt_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function subtract_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.subtract_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function tan_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.tan_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function tanh_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.tanh_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function torch_index_select_v1(operand::Value, index::Value; result::IR.Type, dim, batch_dims, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, index, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("dim", dim), namedattribute("batch_dims", batch_dims), ]
+
+    create_operation(
+        "vhlo.torch_index_select_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function transpose_v1(operand::Value; result::IR.Type, permutation, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("permutation", permutation), ]
+
+    create_operation(
+        "vhlo.transpose_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function triangular_solve_v1(a::Value, b::Value; result::IR.Type, left_side, lower, unit_diagonal, transpose_a, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[a, b, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("left_side", left_side), namedattribute("lower", lower), namedattribute("unit_diagonal", unit_diagonal), namedattribute("transpose_a", transpose_a), ]
+
+    create_operation(
+        "vhlo.triangular_solve_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function tuple_v1(val::Vector{Value}; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[val..., ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.tuple_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function unary_einsum_v1(operand::Value; result::IR.Type, einsum_config, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("einsum_config", einsum_config), ]
+
+    create_operation(
+        "vhlo.unary_einsum_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function uniform_dequantize_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.uniform_dequantize_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function uniform_quantize_v1(operand::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[operand, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.uniform_quantize_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function while_v1(operand::Vector{Value}; results::Vector{IR.Type}, cond::Region, body::Region, location=Location())
+    op_ty_results = IR.Type[results..., ]
+    operands = Value[operand..., ]
+    owned_regions = Region[cond, body, ]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.while_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+
+function xor_v1(lhs::Value, rhs::Value; result::IR.Type, location=Location())
+    op_ty_results = IR.Type[result, ]
+    operands = Value[lhs, rhs, ]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    create_operation(
+        "vhlo.xor_v1", location;
+        operands, owned_regions, successors, attributes,
+        results=op_ty_results,
+        result_inference=false
+    )
+end
+
+end # vhlo