feat: add Ops.batch #535

Draft
wants to merge 13 commits into base: main
12 changes: 10 additions & 2 deletions docs/make.jl
@@ -24,8 +24,11 @@ examples = [
pages = [
"Reactant.jl" => "index.md",
"Introduction" => ["Getting Started" => "introduction/index.md"],
"Tutorials" =>
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
"Tutorials" => [
"Overview" => "tutorials/index.md",
"Profiling" => "tutorials/profiling.md",
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
@@ -38,6 +41,11 @@ pages = [
"Func" => "api/func.md",
"StableHLO" => "api/stablehlo.md",
"VHLO" => "api/vhlo.md",
"GPU" => "api/gpu.md",
"LLVM" => "api/llvm.md",
"NVVM" => "api/nvvm.md",
"TPU" => "api/tpu.md",
"Triton" => "api/triton.md",
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
10 changes: 9 additions & 1 deletion docs/src/.vitepress/config.mts
@@ -56,8 +56,12 @@ export default defineConfig({
{
text: "Tutorials",
items: [
{text: "Overview", link: "/tutorials/"},
{ text: "Overview", link: "/tutorials/" },
{text: "Profiling", link: "/tutorials/profiling"},
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching"
},
],
},
{
@@ -112,6 +116,10 @@ export default defineConfig({
items: [
{ text: "Overview", link: "/tutorials/" },
{ text: "Profiling", link: "/tutorials/profiling" },
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching",
},
],
},
"/api/": {
3 changes: 3 additions & 0 deletions docs/src/tutorials/batching.md
@@ -0,0 +1,3 @@
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)


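`Reactant.Ops.batch` maps a function over a batch dimension of its arguments, similar to
`jax.vmap`. By default the last dimension of each input (and of each result) is treated as
the batch dimension, following the usual Julia convention. The example below is only a
minimal sketch of the intended workflow; it assumes the `Reactant.to_rarray` /
`Reactant.@jit` setup used in the other tutorials.

```julia
using Reactant

# A batch of 3 vectors of length 4, batched along the last dimension (the default).
x = Reactant.to_rarray(rand(Float32, 4, 3))
# A single weight matrix shared across the whole batch.
W = Reactant.to_rarray(rand(Float32, 2, 4))

mul(W, v) = W * v

# Apply `mul` to every length-4 slice of `x`. Passing `nothing` as the batch dimension
# for `W` broadcasts it across the batch instead of slicing it.
batched_mul(W, x) = Reactant.Ops.batch(mul, W, x; batch_dims=(nothing, 2))

y = Reactant.@jit batched_mul(W, x)
size(y)  # expected to be (2, 3): the batch dimension stays last by default
```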
1 change: 1 addition & 0 deletions docs/src/tutorials/index.md
@@ -1,5 +1,6 @@
# Tutorials

- [Profiling](@ref profiling).
- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)

We are currently working on adding more tutorials to Reactant!! Please check back soon!
8 changes: 4 additions & 4 deletions ext/ReactantCUDAExt.jl
@@ -758,7 +758,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(A::Type{<:CuTracedArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
@nospecialize(args::Vararg)
)
return A
end
@@ -767,18 +767,18 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(A::Type{<:CUDA.CuArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
@nospecialize(args::Vararg)
)
T = eltype(A)
N = ndims(A)
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
return Reactant.ConcreteRArray{T,N}
else
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers)
TT = Reactant.traced_type_inner(T, seen, mode, args...)
if TT === T
return A
else
return Array{Reactant.traced_type_inner(T, seen, mode, track_numbers),N}
return Array{Reactant.traced_type_inner(T, seen, mode, args...),N}
end
end
end
4 changes: 2 additions & 2 deletions ext/ReactantOffsetArraysExt.jl
@@ -8,11 +8,11 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(OA::Type{<:OffsetArray}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type = Union{})
@nospecialize(args::Vararg)
)
N = ndims(OA)
T = OffsetArrays.parenttype(OA)
T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers)
T2 = Reactant.traced_type_inner(T, seen, mode, args...)
return OffsetArray{eltype(T2),N,T2}
end

233 changes: 232 additions & 1 deletion src/Ops.jl
@@ -12,7 +12,7 @@ using ..Reactant:
RNumber,
MissingTracedValue,
unwrapped_eltype
using Functors: fmap
using Functors: Functors, fmap

function mlir_type(x::Union{RNumber,RArray})
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(unwrapped_eltype(x)))
@@ -1967,4 +1967,235 @@ end
return corrected_traced_results
end

"""
batch(
inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}},
output_types::Vector{<:MLIR.IR.Type},
batch_shape::Vector{Int64};
fn,
location=mlir_stacktrace("batch", @__FILE__, @__LINE__),
)

Generates a Reactant.MLIR.Dialects.enzyme.batch operation. It is recommended to use
`Ops.batch(f, args...; batch_dims, result_dims)` or `Ops.elem_apply(f, args...)` instead
of calling this directly.

!!! warning

This function batches the inputs based on the starting dimensions of the inputs. This
aligns with the default ordering in Python frameworks like JAX and PyTorch, but is
opposite to the default ordering in Julia.
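For example, with `batch_shape = [B]`, an input whose per-sample shape is `(4, 3)` is
expected to be passed to this op with shape `(B, 4, 3)`, i.e. the batch dimensions come first.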
"""
@noinline function batch(
inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}},
output_types::Vector{<:MLIR.IR.Type},
batch_shape::Vector{Int64};
fn,
location=mlir_stacktrace("batch", @__FILE__, @__LINE__),
)
op = MLIR.Dialects.enzyme.batch(
[i isa TracedRArray ? i.mlir_data : i for i in inputs];
outputs=output_types,
fn=MLIR.IR.FlatSymbolRefAttribute(
String(Reactant.TracedUtils.get_attribute_by_name(fn, "sym_name"))
),
batch_shape=MLIR.IR.DenseArrayAttribute(batch_shape),
location,
)

return [
TracedRArray{MLIR.IR.julia_type(eltype(out_type)),ndims(out_type)}(
(), MLIR.IR.result(op, i), size(out_type)
) for (i, out_type) in enumerate(output_types)
]
end

# By default this function assumes that the last dimension of each element is the batch
# dimension, which is the standard Julia ordering for batching. We permute the dimensions
# so that the batch dimension comes first before calling `batch_internal` below.
"""
batch(f, args...; batch_dims=nothing, result_dims=nothing)

Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
familiar with `jax`, this operation corresponds to `jax.vmap`.)

If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.

To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.

## Examples

For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
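
A minimal sketch of the intended usage (illustrative, not a doctest; assumes the usual
`Reactant.to_rarray` / `Reactant.@jit` workflow):

```julia
using Reactant

normalize(v) = v ./ sum(v)

# Normalize each of the 3 length-4 samples independently. Both the input and the result
# keep the batch dimension last (the default for `batch_dims` and `result_dims`).
batched_normalize(x) = Reactant.Ops.batch(normalize, x)

x = Reactant.to_rarray(rand(Float32, 4, 3))
y = Reactant.@jit batched_normalize(x)  # expected size: (4, 3)
```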

!!! danger

Mutation inside a batched function is not supported yet and will lead to unexpected results.
"""
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
batch_sizes = Int64[]
batching_dims = if batch_dims === nothing
fmap(args) do x
tmp = ndims(x)
push!(batch_sizes, size(x, tmp))
return tmp
end
else
fmap(args, batch_dims) do x, dim
dim !== nothing && push!(batch_sizes, size(x, dim))
@assert dim isa Integer || dim === nothing
dim
end
end

batch_sizes_no_ones = filter(x -> x != 1, batch_sizes)
@assert allequal(batch_sizes_no_ones) "batching dimensions must be equal"
B = length(batch_sizes_no_ones) == 0 ? 1 : first(batch_sizes_no_ones)

corrected_args = fmap(args, batching_dims) do arg, dim
if dim === nothing # repeat the input along dim=0
return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
end
if size(arg, dim) == 1 && size(arg, dim) != B # If batch_dim is 1, then expand that dim
new_dims = collect(Int64, size(arg))
new_dims[dim] = B
arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims)
end
order = collect(Int64, 1:ndims(arg))
order[dim] = 1
order[1] = dim
return permutedims(arg, order) # Ensure batch dim is moved to the first position
end

results = batch_internal(f, corrected_args...)

if result_dims === nothing
return fmap(results) do result
order = Int64[2:ndims(result)..., 1]
return permutedims(result, order)
end
end

return fmap(results, result_dims) do result, dim
@assert dim !== nothing "Result batch dimension cannot be `nothing`"

order = collect(Int64, 1:ndims(result))
order[dim] = 1
order[1] = dim
return permutedims(result, order)
end
end

"""
elem_apply(f, args...)

This is equivalent to `f.(args...)` but generates optimized code using
Reactant.MLIR.Dialects.enzyme.batch.
"""
@noinline function elem_apply(f, args::Vararg)
return batch_internal(f, args...; batchmode=Reactant.BatchScalar)
end
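# For example (sketch): for a traced array `x`, `Ops.elem_apply(sin, x)` is intended to
# compute the same values as `sin.(x)`, but lowered as a single enzyme.batch op over the
# scalar function rather than a scalarized loop.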

@noinline function elem_apply(
::Type{T}, x::TracedRArray{T}
) where {T<:Reactant.ReactantPrimitive}
return x
end

@noinline function elem_apply(
::Type{T}, x::TracedRArray
) where {T<:Reactant.ReactantPrimitive}
# Special Path to prevent going down a despecialized path
return elem_apply(Reactant.TracedUtils.TypeCast{T}(), x)
end

@noinline function batch_internal(f, args::Vararg; batchmode=Reactant.BatchArray)
@assert batchmode != Reactant.BatchNone

if batchmode == Reactant.BatchScalar
if all(iszero ∘ ndims, args)
scalar_args = map(args) do arg
return Reactant.TracedUtils.promote_to(
TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg
)
end
return Reactant.call_with_reactant(f, scalar_args...)
end
end

fnwrap, func2, _, result, seen_args, _, linear_args, _, linear_results = Reactant.TracedUtils.make_mlir_fn(
f,
args,
(),
string(f) * (batchmode == Reactant.BatchArray ? "_batch" : "_broadcast_scalar"),
false;
batchmode,
no_args_in_result=batchmode == Reactant.BatchScalar,
do_transpose=false,
)

if batchmode == Reactant.BatchArray
batch_sizes = [size(k, 1) for k in keys(seen_args) if k isa Reactant.TracedType]
@assert allequal(batch_sizes) "batching dimensions must be equal"
B = first(batch_sizes)
else
input_shapes = [size(k) for k in keys(seen_args) if k isa Reactant.TracedType]
@assert allequal(input_shapes) "input shapes are $(input_shapes)"
output_shape = first(input_shapes)
end

batch_inputs = MLIR.IR.Value[]
for a in linear_args
idx, path = Reactant.TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
Reactant.TracedUtils.push_val!(batch_inputs, f, path[3:end])
else
fnwrap && (idx -= 1)
Reactant.TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
end
end

res = batch(
batch_inputs,
[
MLIR.IR.TensorType(
batchmode == Reactant.BatchArray ? (B, size(arg)...) : output_shape,
MLIR.IR.Type(Reactant.unwrapped_eltype(arg)),
) for arg in linear_results
],
batchmode == Reactant.BatchArray ? Int64[B] : collect(Int64, output_shape);
fn=func2,
)

residx = 1
for a in linear_results
if Reactant.TracedUtils.has_residx(a)
path = Reactant.TracedUtils.get_residx(a)
Reactant.TracedUtils.set!(result, path[2:end], res[residx])
residx += 1
else
idx, path = Reactant.TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
Reactant.TracedUtils.set!(f, path[3:end], res[residx])
residx += 1
else
fnwrap && (idx -= 1)
Reactant.TracedUtils.set!(args[idx], path[3:end], res[residx])
residx += 1
end
end
end

traced2_result = Reactant.make_tracer(
Reactant.OrderedIdDict(),
result,
(),
Reactant.TracedSetPath;
tobatch=batchmode == Reactant.BatchArray ? (B,) : output_shape,
batchmode,
)
func2.operation = MLIR.API.MlirOperation(C_NULL)

return traced2_result
end

end # module Ops
6 changes: 3 additions & 3 deletions src/TracedRArray.jl
@@ -549,7 +549,7 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)

res = TracedUtils.promote_to(
TracedRArray{unwrapped_eltype(dest),ndims(dest)},
TracedUtils.elem_apply(bc.f, args...),
Ops.elem_apply(bc.f, args...),
Comment on lines 551 to +552 (Contributor)

[JuliaFormatter] reported by reviewdog 🐶

Suggested change (join onto one line):
        TracedRArray{unwrapped_eltype(dest),ndims(dest)}, Ops.elem_apply(bc.f, args...)
)
TracedUtils.set_mlir_data!(dest, res.mlir_data)
return dest
@@ -563,8 +563,8 @@ function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted)

args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = TracedUtils.elem_apply(bc.f, args...)
for I in 1:length(dest)
res = Ops.elem_apply(bc.f, args...)
for I in eachindex(dest)
dest[I] = Reactant.@allowscalar res[I]
end
return dest