Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support passing in a nothing in dims
Browse files Browse the repository at this point in the history
avik-pal committed Jan 15, 2025
1 parent c2b8e46 commit 945408c
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/Ops.jl
Original file line number Diff line number Diff line change
@@ -1557,19 +1557,30 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
)
end

# XXX: Support linearization and de-linearization
# XXX: some of the args are not batched (use nothing)
function batch(
f, args::Vector{<:TracedRArray}, batch_dims::Vector{Int}, result_dims::Vector{Int}
f,
args::Vector{<:TracedRArray},
batch_dims::Vector{Union{Int,Nothing}},
result_dims::Union{Vector{Int},Nothing}=nothing,
)
@assert length(batch_dims) == length(args)

batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)]
filter!(x -> x != 1, batch_sizes)
@assert allequal(batch_sizes) "batching dimensions must be equal"
B = length(batch_sizes) == 0 ? 1 : first(batch_sizes)

args = map(zip(args, batch_dims)) do (arg, dim)
if dim === nothing
return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
end
order = collect(1:ndims(arg))
order[dim] = 1
order[1] = dim
return permutedims(arg, order)
end
results = batch(f, args)
result_dims === nothing && (result_dims = ones(Int64, length(results)))
@assert length(results) == length(result_dims)
return map(zip(results, result_dims)) do (result, dim)
order = collect(1:ndims(result))

0 comments on commit 945408c

Please sign in to comment.