Commit
feat: add support for the remaining wrapper types (#369)
* feat: add materialize_traced_array for all other wrappers

* refactor: use scatter for generating diagm

* refactor: directly generate the region for simple_scatter_op

* feat: generalize diagm

* feat: efficient non-contiguous setindex

* fix: non-contiguous indexing is now supported

* feat: implement set_mlir_data for the remaining types

* refactor: use `Ops.gather_getindex` to implement diag

* fix: noinline ops

* fix: incorrect rebase

* fix: dispatches

* fix: diagm for repeated indices and initial tests

* fix: higher dimensional indexing + tests

* fix: matrix multiplication of wrapper types

* fix: de-specialize 3 arg mul!
avik-pal authored Dec 29, 2024
1 parent d4e7c76 commit 8e4c095
Showing 10 changed files with 565 additions and 142 deletions.
5 changes: 4 additions & 1 deletion src/Compiler.jl
@@ -41,6 +41,9 @@ function create_result(tocopy::T, path, result_stores) where {T}
elems = Union{Symbol,Expr}[]

for i in 1:fieldcount(T)
# If the field is undefined we don't set it. A common example for this is `du2`
# for Tridiagonal
isdefined(tocopy, i) || continue
ev = create_result(getfield(tocopy, i), append_path(path, i), result_stores)
push!(elems, ev)
end
@@ -102,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic
end

function create_result(
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char},
path,
result_stores,
)
110 changes: 110 additions & 0 deletions src/Ops.jl
@@ -1418,4 +1418,114 @@ julia> Reactant.@jit(
end
end

"""
scatter_setindex(dest, scatter_indices, updates)
Uses [`MLIR.Dialects.stablehlo.scatter`](@ref) to set the values of `dest` at the indices
specified by `scatter_indices` to the values in `updates`. If the indices are contiguous it
is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref)
instead.
"""
@noinline function scatter_setindex(
dest::TracedRArray{T,N},
scatter_indices::TracedRArray{Int64,2},
updates::TracedRArray{T,1},
) where {T,N}
@assert length(updates) == size(scatter_indices, 1)
@assert size(scatter_indices, 2) == N

update_computation = MLIR.IR.Region()
block = MLIR.IR.Block(
[mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})],
[MLIR.IR.Location(), MLIR.IR.Location()],
)
return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
MLIR.IR.rmfromparent!(return_op)
push!(block, return_op)
pushfirst!(update_computation, block)

#! format: off
update_window_dims = Int64[]
inserted_window_dims = collect(Int64, 0:(N - 1))
input_batching_dims = Int64[]
scatter_indices_batching_dims = Int64[]
scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1))
index_vector_dim = Int64(1)

scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
MLIR.IR.context(),
length(update_window_dims), update_window_dims,
length(inserted_window_dims), inserted_window_dims,
length(input_batching_dims), input_batching_dims,
length(scatter_indices_batching_dims), scatter_indices_batching_dims,
length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims,
index_vector_dim,
)
#! format: on

return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.scatter(
[dest.mlir_data],
scatter_indices.mlir_data,
[updates.mlir_data];
result_0=[mlir_type(TracedRArray{T,N}, size(dest))],
update_computation,
scatter_dimension_numbers,
),
1,
),
size(dest),
)
end
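
A rough usage sketch of `scatter_setindex` (the helper name and shapes below are hypothetical; it assumes the call happens while tracing and that `Reactant.Ops` / `Reactant.TracedUtils` are reachable as written):

# Each row of `idxs` holds the 0-based (i, j) coordinates of one element to overwrite,
# here dest[1, 2] and dest[3, 1]; `vals` must have one entry per row.
function set_two_entries!(dest::Reactant.TracedRArray{Float64,2}, vals::Reactant.TracedRArray{Float64,1})
    idxs = Reactant.TracedUtils.promote_to(Reactant.TracedRArray{Int,2}, [0 1; 2 0])
    return Reactant.Ops.scatter_setindex(dest, idxs, vals)
end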

"""
gather_getindex(src, gather_indices)
Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices
specified by `gather_indices`. If the indices are contiguous it is recommended to directly
use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
"""
@noinline function gather_getindex(
src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
) where {T,N}
@assert size(gather_indices, 2) == N

#! format: off
offset_dims = Int64[1]
collapsed_slice_dims = collect(Int64, 0:(N - 2))
operand_batching_dims = Int64[]
start_indices_batching_dims = Int64[]
start_index_map = collect(Int64, 0:(N - 1))
index_vector_dim = Int64(1)

dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(length(offset_dims)), offset_dims,
Int64(length(collapsed_slice_dims)), collapsed_slice_dims,
Int64(length(operand_batching_dims)), operand_batching_dims,
Int64(length(start_indices_batching_dims)), start_indices_batching_dims,
Int64(length(start_index_map)), start_index_map,
Int64(index_vector_dim),
)
#! format: on

return reshape(
TracedRArray{T}(
MLIR.IR.result(
MLIR.Dialects.stablehlo.gather(
src.mlir_data,
gather_indices.mlir_data;
dimension_numbers,
slice_sizes=fill(Int64(1), N),
indices_are_sorted=false,
),
1,
),
),
size(gather_indices, 1),
)
end
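
A matching read-side sketch for `gather_getindex` (same assumptions as above; the result has one entry per row of the index matrix):

function get_two_entries(src::Reactant.TracedRArray{Float64,2})
    idxs = Reactant.TracedUtils.promote_to(Reactant.TracedRArray{Int,2}, [0 1; 2 0])
    return Reactant.Ops.gather_getindex(src, idxs)  # length-2 traced vector: src[1, 2], src[3, 1]
end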

end # module Ops
27 changes: 27 additions & 0 deletions src/Overlay.jl
@@ -115,3 +115,30 @@ for randfun in (:rand, :randn, :randexp)
# end
end
end

# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
for (cT, aT, bT) in (
(:AbstractVector, :AbstractMatrix, :AbstractVector),
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
)
@eval begin
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
end
return C
end

# Needed mostly for 1.10 where 3-arg mul is often specialized
@reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT)
call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
return C
end
end
end
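
A hedged end-to-end sketch of what this overlay buys (using `Reactant.to_rarray` and `@jit` as in the tests; the wrapper function is made up for illustration):

using LinearAlgebra, Reactant

matmul!(C, A, B) = (mul!(C, A, B); C)  # under @jit this should route to TracedLinearAlgebra.overloaded_mul!

A = Reactant.to_rarray(rand(4, 4))
B = Reactant.to_rarray(rand(4, 4))
C = Reactant.to_rarray(zeros(4, 4))
Reactant.@jit matmul!(C, A, B)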
20 changes: 14 additions & 6 deletions src/Reactant.jl
@@ -105,7 +105,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
) where {T,N}
shape = Tuple(shape)
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == shape
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
end
return new{T,N}(paths, mlir_data, shape)
end
@@ -119,15 +119,23 @@ const WrappedTracedRArray{T,N} = WrappedArray{
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
const AnyTracedRMatrix{T} = Union{
AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}}
AnyTracedRArray{T,2},
LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}},
LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}},
}
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

function TracedRArray(data::MLIR.IR.Value)
function TracedRArray{T}(data::MLIR.IR.Value) where {T}
data_type = MLIR.IR.type(data)
return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}(
(), data, size(data_type)
)
if T == eltype(MLIR.IR.julia_type(data_type))
return TracedRArray{T,ndims(data_type)}((), data, size(data_type))
end
tdata = TracedRArray(data)
return Ops.convert(TracedRArray{T,ndims(data_type)}, tdata)
end

function TracedRArray(data::MLIR.IR.Value)
return TracedRArray{eltype(MLIR.IR.julia_type(MLIR.IR.type(data)))}(data)
end

struct XLAArray{T,N} <: RArray{T,N} end
74 changes: 54 additions & 20 deletions src/TracedRArray.jl
@@ -14,8 +14,9 @@ using ..Reactant:
MLIR,
ancestor,
unwrapped_eltype
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

using ReactantCore: ReactantCore
using ..TracedUtils: TracedUtils, materialize_traced_array
using GPUArraysCore: GPUArraysCore

ReactantCore.is_traced(::TracedRArray) = true
@@ -55,25 +56,37 @@ function Base.getindex(
return TracedRNumber{T}((), res2)
end

function Base.getindex(a::TracedRArray{T,0}) where {T}
return TracedRNumber{T}((), a.mlir_data)
end
Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)

# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
return i
end

foreach(indices) do idxs
idxs isa Number && return nothing
non_contiguous_getindex = false
for idxs in indices
idxs isa Number && continue
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(a) <: Bool
contiguous || error("non-contiguous indexing is not supported")
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_getindex = true
break
end
end

if non_contiguous_getindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.gather_getindex(a, indices)
return Ops.reshape(res, size(indices_tuples)...)
end

start_indices = map(indices) do i
@@ -99,16 +112,41 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
end

function Base.setindex!(
a::TracedRArray{T,N},
v,
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
return i
end

non_contiguous_setindex = false
for idxs in indices
idxs isa Number && continue
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_setindex = true
break
end
end

if non_contiguous_setindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
a.mlir_data = res.mlir_data
return v
end

v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)

indices = [
(
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
@@ -124,11 +162,7 @@ function Base.setindex!(
return v
end

function Base.setindex!(
a::AnyTracedRArray{T,N},
v,
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
function Base.setindex!(a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...)
setindex!(ancestor(a), v, ancestor_indices...)
return a
62 changes: 10 additions & 52 deletions src/TracedUtils.jl
@@ -3,7 +3,6 @@
# within compilation. However, it means these functions are a _lot_ faster to compile.
module TracedUtils

using LinearAlgebra: LinearAlgebra
using Adapt: Adapt, WrappedReshapedArray
using ..Reactant:
Reactant,
@@ -19,34 +18,20 @@ using ..Reactant:
Ops

materialize_traced_array(x::TracedRArray) = x

materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]

function materialize_traced_array(
x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}
) where {T,N,M}
return Ops.reshape(materialize_traced_array(parent(x)), size(x)...)
end
function materialize_traced_array(
x::LinearAlgebra.Transpose{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
px = parent(x)
A = ndims(px) == 1 ? reshape(px, :, 1) : px
return permutedims(A, (2, 1))
end
function materialize_traced_array(
x::LinearAlgebra.Adjoint{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
return conj(materialize_traced_array(transpose(parent(x))))
end

function materialize_traced_array(
x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}
) where {T,N,perm,iperm}
return permutedims(parent(x), perm)
end
function materialize_traced_array(
x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return LinearAlgebra.diagm(parent(x))
end

get_mlir_data(x::TracedRNumber) = x.mlir_data
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
@@ -58,51 +43,24 @@ function set_mlir_data!(x::TracedRArray, data)
x.mlir_data = data
return x
end

function set_mlir_data!(
x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, data
) where {T,N,M}
res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data
res_mlir_data = Ops.reshape(TracedRArray{T}(data), size(parent(x))...).mlir_data
set_mlir_data!(parent(x), res_mlir_data)
return x
end
function set_mlir_data!(
x::LinearAlgebra.Transpose{TracedRNumber{T},TracedRArray{T,N}}, data
) where {T,N}
tdata = TracedRArray(data)
px = parent(x)
px.mlir_data = (
if ndims(px) == 1
Ops.reshape(tdata, length(tdata))
else
Ops.transpose(tdata, [2, 1])
end
).mlir_data
return x
end
function set_mlir_data!(
x::LinearAlgebra.Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data
) where {T,N}
tdata = TracedRArray(data)
px = parent(x)
transposed_data =
ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1])
px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data
return x
end

function set_mlir_data!(
x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}, data
) where {T,N,perm,iperm}
parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data
parent(x).mlir_data = permutedims(TracedRArray{T}(data), iperm).mlir_data
return x
end
function set_mlir_data!(
x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data
return x
end
function set_mlir_data!(x::AnyTracedRArray, data)
setindex!(x, TracedRArray(data), axes(x)...)

function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
setindex!(x, TracedRArray{T}(data), axes(x)...)
return x
end

264 changes: 208 additions & 56 deletions src/stdlibs/LinearAlgebra.jl
@@ -1,20 +1,170 @@
module TracedLinearAlgebra

using ..Reactant
import ..TracedRArray
import ..TracedRNumber
import ..AnyTracedRArray
import ..AnyTracedRMatrix
import ..AnyTracedRVector

import ..TracedUtils
using ..TracedUtils: get_mlir_data, materialize_traced_array, set_mlir_data!

import ..Ops
import ..MLIR
using ..Reactant:
TracedRArray,
TracedRNumber,
AnyTracedRArray,
AnyTracedRMatrix,
AnyTracedRVector,
Ops,
MLIR

using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data!

using LinearAlgebra

function LinearAlgebra.mul!(
# Various Wrapper Arrays defined in LinearAlgebra
function TracedUtils.materialize_traced_array(
x::Transpose{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
px = parent(x)
A = ndims(px) == 1 ? reshape(px, :, 1) : px
return permutedims(A, (2, 1))
end

function TracedUtils.materialize_traced_array(
x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
return conj(materialize_traced_array(transpose(parent(x))))
end

function TracedUtils.materialize_traced_array(
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return diagm(parent(x))
end

function TracedUtils.materialize_traced_array(
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
end

for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
uAT = Symbol(:Unit, AT)
@eval begin
function TracedUtils.materialize_traced_array(
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(comp))
return Ops.select(indicator, parent(x), zero(parent(x)))
end

function TracedUtils.materialize_traced_array(
x::$(uAT){TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
nondiag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="NE")
x = materialize_traced_array($(AT)(parent(x)))
return Ops.select(nondiag_indicator, x, one.(x))
end
end
end
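
For intuition, the `iota`/`compare`/`select` pattern above is the traced analogue of masking with index grids; a plain-Julia illustration (not traced, assumes LinearAlgebra is loaded):

m, n = 3, 3
row_idxs = [i for i in 1:m, j in 1:n]
col_idxs = [j for i in 1:m, j in 1:n]
mask = row_idxs .>= col_idxs                            # "GE" ⇒ keep the lower triangle
A = rand(m, n)
ifelse.(mask, A, zero(A)) == Array(LowerTriangular(A))  # true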

function TracedUtils.materialize_traced_array(
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
if x.uplo == 'L'
indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="GT")
x_lt = Ops.select(indicator, parent(x), zero(parent(x)))
x_ltd = materialize_traced_array(LowerTriangular(parent(x)))
return Ops.add(x_lt, Ops.transpose(x_ltd, [2, 1]))
else
indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="LT")
x_ut = Ops.select(indicator, parent(x), zero(parent(x)))
x_utd = materialize_traced_array(UpperTriangular(parent(x)))
return Ops.add(Ops.transpose(x_utd, [2, 1]), x_ut)
end
end

function TracedUtils.set_mlir_data!(
x::Transpose{TracedRNumber{T},TracedRArray{T,N}}, data
) where {T,N}
tdata = TracedRArray{T}(data)
px = parent(x)
px.mlir_data = (
if ndims(px) == 1
Ops.reshape(tdata, length(tdata))
else
Ops.transpose(tdata, [2, 1])
end
).mlir_data
return x
end

function TracedUtils.set_mlir_data!(
x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data
) where {T,N}
tdata = TracedRArray{T}(data)
px = parent(x)
transposed_data =
ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1])
px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data
return x
end

function TracedUtils.set_mlir_data!(
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data
return x
end

for (AT, dcomp, ocomp) in (
(:LowerTriangular, "GE", "LT"),
(:UnitLowerTriangular, "GT", "LE"),
(:UpperTriangular, "LE", "GT"),
(:UnitUpperTriangular, "LT", "GE"),
)
@eval function TracedUtils.set_mlir_data!(
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data
) where {T}
tdata = TracedRArray{T}(data)
z = zero(tdata)
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
data_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(dcomp))
original_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(ocomp))
res = Ops.add(
Ops.select(data_indicator, tdata, z), Ops.select(original_indicator, x.data, z)
)
set_mlir_data!(x.data, res.mlir_data)
return x
end
end

function TracedUtils.set_mlir_data!(
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data
) where {T}
if x.uplo == 'L'
set_mlir_data!(LowerTriangular(parent(x)), data)
else
set_mlir_data!(UpperTriangular(parent(x)), data)
end
return x
end

function TracedUtils.set_mlir_data!(
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
tdata = TracedRArray{T}(data)
set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
set_mlir_data!(x.d, diag(tdata, 0).mlir_data)
set_mlir_data!(x.du, diag(tdata, 1).mlir_data)
return x
end

# Core functions
function overloaded_mul!(
@nospecialize(C::TracedRArray{T,1}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRVector),
@@ -23,23 +173,23 @@ function LinearAlgebra.mul!(
) where {T}
# TODO: The reshape operations are not getting optimized, we should directly call dot_general
rC = Ops.reshape(C, length(C), 1)
LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β)
overloaded_mul!(rC, A, reshape(B, :, 1), α, β)
C.mlir_data = get_mlir_data(vec(rC))
return C
end

function LinearAlgebra.mul!(
function overloaded_mul!(
@nospecialize(C::TracedRArray{T,2}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRVector),
α::Number=true,
β::Number=false,
) where {T}
LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β)
overloaded_mul!(C, A, reshape(B, :, 1), α, β)
return C
end

function LinearAlgebra.mul!(
function overloaded_mul!(
@nospecialize(C::TracedRArray{T,2}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRMatrix),
@@ -119,50 +269,52 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
# <unknown>:0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64>
length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[])

idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,2}, indices))

#! format: off
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(0), Int64[],
Int64(2), Int64[0, 1],
Int64(0), Int64[],
Int64(0), Int64[],
Int64(2), Int64[0, 1],
Int64(1)
)
#! format: on

slice_sizes = get_mlir_data(
Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1])
)
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dynamic_gather(
get_mlir_data(y), idxs, slice_sizes; dimension_numbers
),
1,
)
return TracedRArray{T,1}((), res, (diag_length,))
return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices))
end

function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
return LinearAlgebra.diagm(length(v), length(v), v)
end
function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check
function LinearAlgebra._diagm(
shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
) where {T}
m, n = LinearAlgebra.diagm_size(shape, kv...)

v = materialize_traced_array(v)
D = length(v)
row_idxs = Ops.iota(Int, [D, D]; iota_dimension=1)
col_idxs = Ops.iota(Int, [D, D]; iota_dimension=2)
diag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="EQ")
# For repeated indices we need to aggregate the values
kv_updated = Dict{Integer,AnyTracedRArray{T,1}}()
for (k, v) in kv
if haskey(kv_updated, k)
kv_updated[k] = kv_updated[k] + v
else
kv_updated[k] = v
end
end

mat = (v .+ zero(v)') .* diag_indicator
return Ops.pad(
mat,
TracedUtils.promote_to(TracedRNumber{T}, 0);
high=[m - length(v), n - length(v)],
scatter_indices = Matrix{Int64}[]
concat_inputs = MLIR.IR.Value[]
for (k, v) in pairs(kv_updated)
push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
push!(concat_inputs, get_mlir_data(v))
end
scatter_indices = Ops.constant(reduce(vcat, scatter_indices))
values = TracedRArray{T,1}(
(),
MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
(size(scatter_indices, 1),),
)
return Ops.scatter_setindex(
Ops.constant(fill(zero(T), (m, n))), scatter_indices, values
)
end
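
For reference, the aggregation above mirrors Base's `diagm` handling of repeated diagonals (also exercised by the diagm tests later in this diff): repeated keys are summed, so, schematically,

# diagm(1 => v, 1 => w) builds the same matrix as diagm(1 => v .+ w)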

# Common Utilities
## The cartesian version doesn't exist in julia 1.10
function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
idx1, idx2 = 1 + max(0, -k), 1 + max(0, k)
L = max(0, k ≤ 0 ? min(m + k, n) : min(m, n - k))
indices = Matrix{Int}(undef, (L, 2))
for i in axes(indices, 1)
indices[i, 1] = idx1 + i - 2
indices[i, 2] = idx2 + i - 2
end
return indices
end
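
For example, evaluating the helper above in plain Julia (illustrative):

# diagonal_indices_zero_indexed(3, 4, 1) — 0-based coordinates of the first
# superdiagonal of a 3×4 matrix:
#   0  1
#   1  2
#   2  3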

end
74 changes: 74 additions & 0 deletions test/basic.jl
@@ -442,6 +442,26 @@ end
@test @allowscalar all(isone, x_ra_array[4, :])
end

function non_contiguous_setindex!(x)
x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0
return x
end

@testset "non-contiguous setindex!" begin
x = rand(6, 6)
x_ra = Reactant.to_rarray(x)

y = @jit(non_contiguous_setindex!(x_ra))
y = Array(y)
x_ra = Array(x_ra)
@test all(isone, y[1:3, 1:4])
@test all(isone, x_ra[1:3, 1:4])
@test !all(isone, y[4:end, :])
@test !all(isone, x_ra[4:end, :])
@test !all(isone, y[:, 5:end])
@test !all(isone, x_ra[:, 5:end])
end

tuple_byref(x) = (; a=(; b=x))
tuple_byref2(x) = abs2.(x), tuple_byref2(x)

@@ -717,3 +737,57 @@ end
@test res[1] isa ConcreteRArray{Float64,2}
@test res[2] isa ConcreteRNumber{Float64}
end

@testset "non-contiguous indexing" begin
x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1(x) = x[[1, 3, 2], :, :]
non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]]

@test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x)
@test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x)

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1(x) = x[[1, 3, 2], :]
non_contiguous_indexing2(x) = x[:, [1, 2, 2]]

@test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x)
@test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x)

x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2
non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2

@jit(non_contiguous_indexing1!(x_ra))
non_contiguous_indexing1!(x)
@test x_ra ≈ x

x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

@jit(non_contiguous_indexing2!(x_ra))
non_contiguous_indexing2!(x)
@test x_ra ≈ x

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1!(x) = x[[1, 3, 2], :] .= 2
non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2

@jit(non_contiguous_indexing1!(x_ra))
non_contiguous_indexing1!(x)
@test x_ra ≈ x

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

@jit(non_contiguous_indexing2!(x_ra))
non_contiguous_indexing2!(x)
@test x_ra ≈ x
end
41 changes: 34 additions & 7 deletions test/integration/linear_algebra.jl
@@ -1,4 +1,4 @@
using LinearAlgebra, Reactant
using LinearAlgebra, Reactant, Test

function muladd2(A, x, b)
C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2))
@@ -130,15 +130,42 @@ end
@test @jit(diagm(4, 5, x_ra)) ≈ diagm(4, 5, x)
@test @jit(diagm(6, 6, x_ra)) ≈ diagm(6, 6, x)
@test_throws DimensionMismatch @jit(diagm(3, 3, x_ra))

x1 = rand(3)
x2 = rand(3)
x3 = rand(2)
x_ra1 = Reactant.to_rarray(x1)
x_ra2 = Reactant.to_rarray(x2)
x_ra3 = Reactant.to_rarray(x3)

@test @jit(diagm(1 => x_ra1)) ≈ diagm(1 => x1)
@test @jit(diagm(1 => x_ra1, -1 => x_ra3)) ≈ diagm(1 => x1, -1 => x3)
@test @jit(diagm(1 => x_ra1, 1 => x_ra2)) ≈ diagm(1 => x1, 1 => x2)
end

# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be
# optimized
# TODO: Currently <Wrapper Type>(x) * x goes down the generic matmul path but it should
# clearly be optimized
mul_diagonal(x) = Diagonal(x) * x

@testset "mul_diagonal" begin
x = rand(4)
mul_tridiagonal(x) = Tridiagonal(x) * x
mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x
mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x
mul_lower_triangular(x) = LowerTriangular(x) * x
mul_upper_triangular(x) = UpperTriangular(x) * x
mul_symmetric(x) = Symmetric(x) * x

@testset "Wrapper Types Matrix Multiplication" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x)
@testset "$(wrapper_type)" for (wrapper_type, fn) in [
(Diagonal, mul_diagonal),
(Tridiagonal, mul_tridiagonal),
(UnitLowerTriangular, mul_unit_lower_triangular),
(UnitUpperTriangular, mul_unit_upper_triangular),
(LowerTriangular, mul_lower_triangular),
(UpperTriangular, mul_upper_triangular),
(Symmetric, mul_symmetric),
]
@test @jit(fn(x_ra)) ≈ fn(x)
end
end
30 changes: 30 additions & 0 deletions test/wrapped_arrays.jl
@@ -172,3 +172,33 @@ end
@test all(iszero, y_res)
end
end

function lower_triangular_write(x)
y = LowerTriangular(copy(x))
@. y *= 2
return y
end

function upper_triangular_write(x)
y = UpperTriangular(copy(x))
@. y *= 2
return y
end

function tridiagonal_write(x)
y = Tridiagonal(copy(x))
@. y *= 2
return y
end

@testset "Broadcasted Multiply and Alloate" begin
@testset "$(aType)" for (aType, fn) in [
("LowerTriangular", lower_triangular_write),
("UpperTriangular", upper_triangular_write),
("Tridiagonal", tridiagonal_write),
]
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)
@test @jit(fn(x_ra)) ≈ fn(x)
end
end

2 comments on commit 8e4c095

@wsmoses
Member

@JuliaRegistrator
Registration pull request created: JuliaRegistries/General/122139

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.11 -m "<description of version>" 8e4c095a4d510366ae054127f368fdba5a92d88f
git push origin v0.2.11
