From 4a75bf6c98f9df82eba452659b1a8c60af2cf814 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 2 Jan 2025 11:04:32 -0500 Subject: [PATCH] feat: partial support for boolean indexing (#457) --- src/TracedRArray.jl | 12 ++++++++++++ test/basic.jl | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8fb56623c..3d0cd7dfe 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -119,6 +119,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} indices = map(enumerate(indices)) do (idx, i) i isa Colon && return 1:size(a, idx) i isa CartesianIndex && return Tuple(i) + i isa AbstractArray{<:Bool} && return findall(i) return i end @@ -137,6 +138,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} end if use_gather_getindex + # TODO: This will create a dynamically sized tensor and we need to implement + # `findall` for it. + if any(i -> unwrapped_eltype(i) <: Bool, indices) + error("Boolean indexing with TracedRArrays isn't fully supported yet.") + end indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices) indices_list = generate_index_list(indices_list...) res = Ops.gather_getindex(a, indices_list) @@ -170,6 +176,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { indices = map(enumerate(indices)) do (idx, i) i isa Colon && return 1:size(a, idx) i isa CartesianIndex && return Tuple(i) + i isa AbstractArray{<:Bool} && return findall(i) return i end @@ -188,6 +195,11 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { end if use_scatter_setindex + # TODO: This will create a dynamically sized tensor and we need to implement + # `findall` for it. + if any(i -> unwrapped_eltype(i) <: Bool, indices) + error("Boolean indexing with TracedRArrays isn't fully supported yet.") + end indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices) indices_list = generate_index_list(indices_list...) res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v))) diff --git a/test/basic.jl b/test/basic.jl index 8fe89dbf6..ca5ce7729 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -873,3 +873,15 @@ end @test @jit(s3(x, y)) isa Any @test @jit(s4(x, y)) isa Any end + +@testset "Boolean Indexing" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) + idxs_ra = Reactant.to_rarray(rand(Bool, 16)) + + fn(x, idxs) = x[:, idxs] + + @test_throws ErrorException @jit(fn(x_ra, idxs_ra)) + + res = @jit fn(x_ra, Array(idxs_ra)) + @test res ≈ fn(Array(x_ra), Array(idxs_ra)) +end