Skip to content

Commit

Permalink
Refactor getindex to not mutate
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Nov 9, 2023
1 parent 22c341f commit 869ed56
Showing 1 changed file with 171 additions and 96 deletions.
267 changes: 171 additions & 96 deletions src/blocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,29 @@ end
# getindex and setindex! using Vararg{Int,N} signature is needed for the AbstractArray
# interface, manually dispatch through to CartesianIndex{N} signature to work with Dict.

Base.delete!(t::BlockTensorMap, I::CartesianIndex) = delete!(t.data, I)

@inline function Base.get(t::BlockTensorMap, I::CartesianIndex)
@boundscheck checkbounds(t, I)
return get(t.data, I) do
return TensorMap(zeros, scalartype(t), getsubspace(space(t), I))
end
end

@inline function Base.get!(t::BlockTensorMap, I::CartesianIndex)
@boundscheck checkbounds(t, I)
return get!(t.data, I) do
return TensorMap(zeros, scalartype(t), getsubspace(space(t), I))
end
end

@propagate_inbounds function Base.getindex(t::BlockTensorArray{T,N},
I::Vararg{Int,N}) where {T,N}
return getindex(t, CartesianIndex(I))
end
@inline function Base.getindex(t::BlockTensorArray{T,N},
I::CartesianIndex{N}) where {T,N}
@boundscheck checkbounds(t, I)
# "hide" creation behind a function call such that it is only executed when needed
return get!(t.data, I) do
return TensorMap(zeros, scalartype(T), getsubspace(space(t), I))
end
return get(t, I)
end

@propagate_inbounds function Base.setindex!(t::BlockTensorArray{T,N}, v,
Expand All @@ -93,7 +105,7 @@ end
# specialisations to have scalar indexing return a TensorMap
# while non-scalar indexing yields a BlockTensorMap

_newindex(i::Int, range::Int) = i == range ? () : nothing
_newindex(i::Int, range::Int) = i == range ? (1,) : nothing
function _newindex(i::Int, range::AbstractVector{Int})
k = findfirst(==(i), range)
return k === nothing ? nothing : (k,)
Expand All @@ -112,7 +124,6 @@ function Base._unsafe_getindex(::IndexCartesian, t::BlockTensorArray{T,N},
indices = Base.to_indices(t, I)
shape = length.(Base.index_shape(indices...))
# size(dest) == shape || Base.throw_checksize_error(dest, shape)

for (k, v) in nonzero_pairs(t)
newI = _newindices(k.I, indices)
if newI !== nothing
Expand Down Expand Up @@ -180,38 +191,105 @@ nonzero_length(a::BlockTensorMap) = length(a.data)
# Show
# ----
function Base.show(io::IO, ::MIME"text/plain", x::BlockTensorMap)
xnnz = nonzero_length(x)
print(io, Base.join(size(x), "×"), " ", typeof(x), " with ", xnnz, " stored ",
xnnz == 1 ? "entry" : "entries")
if xnnz != 0
println(io, ":")
show(IOContext(io, :typeinfo => eltype(x)), x)
compact = get(io, :compact, false)::Bool
nnz = nonzero_length(x)
print(io, Base.join(size(x), "×"), " BlockTensorMap(", space(x), ")")
if !compact && nnz != 0
println(io, " with ", nnz, " stored entr", nnz == 1 ? "y" : "ies", ":")
show_braille(io, x)
end
return nothing
end
Base.show(io::IO, x::BlockTensorMap) = show(convert(IOContext, io), x)
function Base.show(io::IOContext, x::BlockTensorMap)
nzind = nonzero_keys(x)
if isempty(nzind)
return show(io, MIME("text/plain"), x)
function Base.show(io::IO, x::BlockTensorMap)
compact = get(io, :compact, false)::Bool
nnz = nonzero_length(x)
print(io, Base.join(size(x), "×"), " BlockTensorMap(", space(x), ")")
if !compact && nnz != 0
println(io, " with ", nnz, " stored entr", nnz == 1 ? "y" : "ies", ":")
show_elements(io, x)
end
return nothing
end

function show_elements(io::IO, x::BlockTensorMap)
nzind = nonzero_keys(x)
length(nzind) == 0 && return nothing
limit = get(io, :limit, false)::Bool
compact = get(io, :compact, true)::Bool
half_screen_rows = limit ? div(displaysize(io)[1] - 8, 2) : typemax(Int)
pads = map(1:ndims(x)) do i
return ndigits(maximum(getindex.(nzind, i)))
end
if !haskey(io, :compact)
io = IOContext(io, :compact => true)
end
for (k, (ind, val)) in enumerate(nonzero_pairs(x))
io = IOContext(io, :compact => compact)
nz_pairs = sort(collect(nonzero_pairs(x)); by=first)
for (k, (ind, val)) in enumerate(nz_pairs)
if k < half_screen_rows || k > length(nzind) - half_screen_rows
print(io, " ", '[', Base.join(lpad.(Tuple(ind), pads), ","), "] = ", val)
k != length(nzind) && println(io)
println(io, " ", '[', Base.join(lpad.(Tuple(ind), pads), ","), "] = ", val)
elseif k == half_screen_rows
println(io, " ", Base.join(" " .^ pads, " "), " \u22ee")
end
end
end

# adapted from SparseArrays.jl
const brailleBlocks = UInt16['', '', '', '', '', '', '', '']
function show_braille(io::IO, x::BlockTensorMap)
m = prod(getindices(size(x), codomainind(x)))
n = prod(getindices(size(x), domainind(x)))
reshape_helper = reshape(CartesianIndices((m, n)), size(x))

# The maximal number of characters we allow to display the matrix
local maxHeight::Int, maxWidth::Int
maxHeight = displaysize(io)[1] - 4 # -4 from [Prompt, header, newline after elements, new prompt]
maxWidth = displaysize(io)[2] ÷ 2

if get(io, :limit, true) && (m > 4maxHeight || n > 2maxWidth)
s = min(2maxWidth / n, 4maxHeight / m)
scaleHeight = floor(Int, s * m)
scaleWidth = floor(Int, s * n)
else
scaleHeight = m
scaleWidth = n
end

# Make sure that the matrix size is big enough to be able to display all
# the corner border characters
if scaleHeight < 8
scaleHeight = 8
end
if scaleWidth < 4
scaleWidth = 4
end

brailleGrid = fill(UInt16(10240), (scaleWidth - 1) ÷ 2 + 4, (scaleHeight - 1) ÷ 4 + 1)
brailleGrid[1, :] .= ''
brailleGrid[end - 1, :] .= ''
brailleGrid[1, 1] = ''
brailleGrid[1, end] = ''
brailleGrid[end - 1, 1] = ''
brailleGrid[end - 1, end] = ''
brailleGrid[end, :] .= '\n'

rowscale = max(1, scaleHeight - 1) / max(1, m - 1)
colscale = max(1, scaleWidth - 1) / max(1, n - 1)


for I′ in nonzero_keys(x)
I = reshape_helper[I′]
si = round(Int, (I[1] - 1) * rowscale + 1)
sj = round(Int, (I[2] - 1) * colscale + 1)

k = (sj - 1) ÷ 2 + 2
l = (si - 1) ÷ 4 + 1
p = ((sj - 1) % 2) * 4 + ((si - 1) % 4 + 1)

brailleGrid[k, l] |= brailleBlocks[p]
end

foreach(c -> print(io, Char(c)), @view brailleGrid[1:(end - 1)])
return nothing
end

# Converters
# ----------

Expand All @@ -225,7 +303,7 @@ function Base.convert(::Type{BlockTensorMap{S,N₁,N₂,T₁,N}}, t::BlockTensor
end

function Base.convert(::Type{BlockTensorMap}, t::AbstractTensorMap{S,N₁,N₂}) where {S,N₁,N₂}
tdst = BlockTensorMap{S,N₁,N₂,typeof(t)}(undef, convert(ProductSumSpace, codomain(t)), convert(ProductSumSpace, domain(t)))
tdst = BlockTensorMap{S,N₁,N₂,typeof(t)}(undef, convert(ProductSumSpace{S,N₁}, codomain(t)), convert(ProductSumSpace{S,N₂}, domain(t)))
tdst[1] = t
return tdst
end
Expand Down Expand Up @@ -421,7 +499,7 @@ function VI.scale!(ty::BlockTensorMap, tx::BlockTensorMap,
end
# in-place scale elements from tx (getindex might allocate!)
for (I, v) in nonzero_pairs(tx)
scale!(ty[I], v, α)
ty[I] = scale!(ty[I], v, α)
end
return ty
end
Expand All @@ -438,7 +516,7 @@ function VI.add(y::BlockTensorMap, x::BlockTensorMap, α::Number,
β::Number)
space(y) == space(x) || throw(SpaceMisMatch())
T = VI.promote_add(y, x, α, β)
z = zerovector(t1, T)
z = zerovector(y, T)
# TODO: combine these operations where possible
scale!(z, y, β)
return add!(z, x, α)
Expand All @@ -449,7 +527,7 @@ function VI.add!(y::BlockTensorMap, x::BlockTensorMap, α::Number,
# TODO: combine these operations where possible
scale!(y, β)
for (k, v) in nonzero_pairs(x)
add!(y[k], v, α)
y[k] = add!(y[k], v, α)
end
return y
end
Expand All @@ -468,7 +546,7 @@ function VI.inner(x::BlockTensorMap, y::BlockTensorMap)
end

# TODO: this is type-piracy!
VI.scalartype(::Type{Union{A,B}}) where {A,B} = Union{scalartype(A), scalartype(B)}
# VI.scalartype(::Type{Union{A,B}}) where {A,B} = Union{scalartype(A), scalartype(B)}

# TensorOperations
# ----------------
Expand Down Expand Up @@ -526,7 +604,7 @@ function TO.tensoradd!(C::BlockTensorMap{S}, pC::Index2Tuple,
indCinA = linearize(pC)
for (IA, v) in nonzero_pairs(A)
IC = CartesianIndex(TupleTools.getindices(IA.I, indCinA))
tensoradd!(C[IC], pC, v, conjA, α, One())
C[IC] = tensoradd!(C[IC], pC, v, conjA, α, One())
end
return C
end
Expand All @@ -538,82 +616,79 @@ function TO.tensorcontract!(C::BlockTensorMap{S}, pC::Index2Tuple,
argcheck_tensorcontract(parent(C), pC, parent(A), pA, parent(B), pB)
dimcheck_tensorcontract(parent(C), pC, parent(A), pA, parent(B), pB)

try
scale!(C, β)
scale!(C, β)

keysA = sort!(collect(nonzero_keys(A)); by=IA -> CartesianIndex(getindices(IA.I, pA[2])))
keysB = sort!(collect(nonzero_keys(B));
by=IB -> CartesianIndex(getindices(IB.I, pB[1])))

iA = iB = 1
@inbounds while iA <= length(keysA) && iB <= length(keysB)
IA = keysA[iA]
IB = keysB[iB]
IAc = CartesianIndex(getindices(IA.I, pA[2]))
IBc = CartesianIndex(getindices(IB.I, pB[1]))
if IAc == IBc
Ic = IAc
jA = iA
while jA < length(keysA)
if CartesianIndex(getindices(keysA[jA + 1].I, pA[2])) == Ic
jA += 1
else
break
end
end
jB = iB
while jB < length(keysB)
if CartesianIndex(getindices(keysB[jB + 1].I, pB[1])) == Ic
jB += 1
else
break
end
keysA = sort!(collect(nonzero_keys(A));
by=IA -> CartesianIndex(getindices(IA.I, pA[2])))
keysB = sort!(collect(nonzero_keys(B));
by=IB -> CartesianIndex(getindices(IB.I, pB[1])))

iA = iB = 1
@inbounds while iA <= length(keysA) && iB <= length(keysB)
IA = keysA[iA]
IB = keysB[iB]
IAc = CartesianIndex(getindices(IA.I, pA[2]))
IBc = CartesianIndex(getindices(IB.I, pB[1]))
if IAc == IBc
Ic = IAc
jA = iA
while jA < length(keysA)
if CartesianIndex(getindices(keysA[jA + 1].I, pA[2])) == Ic
jA += 1
else
break
end
rA = iA:jA
rB = iB:jB
if length(rA) < length(rB)
for kB in rB
IB = keysB[kB]
IBo = CartesianIndex(getindices(IB.I, pB[2]))
vB = B[IB]
for kA in rA
IA = keysA[kA]
IAo = CartesianIndex(getindices(IA.I, pA[1]))
IABo = CartesianIndex(IAo, IBo)
IC = CartesianIndex(getindices(IABo.I, linearize(pC)))
vA = A[IA]
tensorcontract!(C[IC], pC, vA, pA, conjA, vB, pB, conjB, α, One())
end
end
end
jB = iB
while jB < length(keysB)
if CartesianIndex(getindices(keysB[jB + 1].I, pB[1])) == Ic
jB += 1
else
break
end
end
rA = iA:jA
rB = iB:jB
if length(rA) < length(rB)
for kB in rB
IB = keysB[kB]
IBo = CartesianIndex(getindices(IB.I, pB[2]))
vB = B[IB]
for kA in rA
IA = keysA[kA]
IAo = CartesianIndex(getindices(IA.I, pA[1]))
IABo = CartesianIndex(IAo, IBo)
IC = CartesianIndex(getindices(IABo.I, linearize(pC)))
vA = A[IA]
for kB in rB
IB = keysB[kB]
IBo = CartesianIndex(getindices(IB.I, pB[2]))
vB = parent(B).data[IB]
IABo = CartesianIndex(IAo, IBo)
IC = CartesianIndex(getindices(IABo.I, linearize(pC)))
tensorcontract!(C[IC], pC, vA, pA, conjA, vB, pB, conjB, α, One())
end
C[IC] = tensorcontract!(C[IC], pC, vA, pA, conjA, vB, pB, conjB, α,
One())
end
end
iA = jA + 1
iB = jB + 1
elseif IAc < IBc
iA += 1
else
iB += 1
for kA in rA
IA = keysA[kA]
IAo = CartesianIndex(getindices(IA.I, pA[1]))
vA = A[IA]
for kB in rB
IB = keysB[kB]
IBo = CartesianIndex(getindices(IB.I, pB[2]))
vB = parent(B).data[IB]
IABo = CartesianIndex(IAo, IBo)
IC = CartesianIndex(getindices(IABo.I, linearize(pC)))
C[IC] = tensorcontract!(C[IC], pC, vA, pA, conjA, vB, pB, conjB, α,
One())
end
end
end
iA = jA + 1
iB = jB + 1
elseif IAc < IBc
iA += 1
else
iB += 1
end
catch ME
println("C", space(C), " ", pC)
println("A", space(A), " ", pA, " ", conjA)
println("B", space(B), " ", pB, " ", conjB)
rethrow(ME)
end

return C
end

Expand All @@ -632,7 +707,7 @@ function TO.tensortrace!(C::BlockTensorMap{S}, pC::Index2Tuple,
IAc1 == IAc2 || continue

IC = CartesianIndex(getindices(IA.I, linearize(pC)))
tensortrace!(C[IC], pC, v, pA, conjA, α, one(β))
C[IC] = tensortrace!(C[IC], pC, v, pA, conjA, α, one(β))
end
return C
end
Expand Down Expand Up @@ -716,7 +791,7 @@ end

# TODO: similar for tensoradd!, tensortrace!

Base.haskey(t::BlockTensorMap, I::CartesianIndex) = haskey(parent(t).data, I)
Base.haskey(t::BlockTensorMap, I::CartesianIndex) = haskey(t.data, I)
function Base.haskey(t::BlockTensorMap, i::Int)
return haskey(parent(t).data, CartesianIndices(t)[i])
return haskey(t.data, CartesianIndices(t)[i])
end

0 comments on commit 869ed56

Please sign in to comment.