Skip to content

Commit

Permalink
feat: partial support for boolean indexing (#457)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Jan 2, 2025
1 parent d903ec6 commit 4a75bf6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
i isa AbstractArray{<:Bool} && return findall(i)
return i
end

Expand All @@ -137,6 +138,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
end

if use_gather_getindex
# TODO: This will create a dynamically sized tensor and we need to implement
# `findall` for it.
if any(i -> unwrapped_eltype(i) <: Bool, indices)
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
end
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
indices_list = generate_index_list(indices_list...)
res = Ops.gather_getindex(a, indices_list)
Expand Down Expand Up @@ -170,6 +176,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
i isa AbstractArray{<:Bool} && return findall(i)
return i
end

Expand All @@ -188,6 +195,11 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
end

if use_scatter_setindex
# TODO: This will create a dynamically sized tensor and we need to implement
# `findall` for it.
if any(i -> unwrapped_eltype(i) <: Bool, indices)
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
end
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
indices_list = generate_index_list(indices_list...)
res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v)))
Expand Down
12 changes: 12 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -873,3 +873,15 @@ end
@test @jit(s3(x, y)) isa Any
@test @jit(s4(x, y)) isa Any
end

@testset "Boolean Indexing" begin
x_ra = Reactant.to_rarray(rand(Float32, 4, 16))
idxs_ra = Reactant.to_rarray(rand(Bool, 16))

fn(x, idxs) = x[:, idxs]

@test_throws ErrorException @jit(fn(x_ra, idxs_ra))

res = @jit fn(x_ra, Array(idxs_ra))
@test res fn(Array(x_ra), Array(idxs_ra))
end

0 comments on commit 4a75bf6

Please sign in to comment.