Skip to content

Commit

Permalink
feat: support dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 15, 2025
1 parent 006b1d0 commit c2b8e46
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1559,8 +1559,26 @@ end

# XXX: Support linearization and de-linearization
# XXX: some of the args are not batched (use nothing)
# XXX: Arbitrary dimensions for batching
# XXX: Out-axis
function batch(
f, args::Vector{<:TracedRArray}, batch_dims::Vector{Int}, result_dims::Vector{Int}
)
@assert length(batch_dims) == length(args)
args = map(zip(args, batch_dims)) do (arg, dim)
order = collect(1:ndims(arg))
order[dim] = 1
order[1] = dim
return permutedims(arg, order)
end
results = batch(f, args)
@assert length(results) == length(result_dims)
return map(zip(results, result_dims)) do (result, dim)
order = collect(1:ndims(result))
order[dim] = 1
order[1] = dim
return permutedims(result, order)
end
end

function batch(f, args::Vector{<:TracedRArray})
batch_sizes = [size(x, 1) for x in args]
@assert allequal(batch_sizes) "batching dimensions must be equal"
Expand Down Expand Up @@ -1625,8 +1643,8 @@ function batch(f, args::Vector{<:TracedRArray})

batch_inputs = [x.mlir_data for x in args]
out_tys = [
MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r)))
for r in result
MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r))) for
r in result
]

op = MLIR.Dialects.enzyme.batch(
Expand Down

0 comments on commit c2b8e46

Please sign in to comment.