From 348161c44718605b1a5e00f5737156b6c54dad2c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 19 Nov 2024 08:21:31 -0500 Subject: [PATCH] more `mul!` things --- src/linalg/linalg.jl | 106 ++++++++++++++++++++++++++++--------- src/tensors/blocktensor.jl | 6 +-- 2 files changed, 83 insertions(+), 29 deletions(-) diff --git a/src/linalg/linalg.jl b/src/linalg/linalg.jl index cbdd246..6868490 100644 --- a/src/linalg/linalg.jl +++ b/src/linalg/linalg.jl @@ -66,39 +66,95 @@ function LinearAlgebra.mul!( compose(space(A), space(B)) == space(C) || throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))")) - Caxes = Base.Fix1(axes, C) - Aaxes = Base.Fix1(axes, A) - - for I in Iterators.product(map(Caxes, codomainind(C))...), - J in Iterators.product(map(Caxes, domainind(C))...) - - did_mul = false - for K in Iterators.product(map(Aaxes, domainind(A))...) - vA = get(A, CartesianIndex(I..., K...), nothing) - isnothing(vA) && continue - vB = get(B, CartesianIndex(K..., J...), nothing) - isnothing(vB) && continue - - vC = get(C, CartesianIndex(I..., J...), nothing) - if did_mul - C[I..., J...] = _mul!!(vC, vA, vB, α, One()) - else - C[I..., J...] = _mul!!(vC, vA, vB, α, β) - did_mul = true + scale!(C, β) + + sortIA(IA) = CartesianIndex(TT.getindices(IA.I, domainind(A))) + keysA = sort!(vec(collect(nonzero_keys(A))); by=sortIA) + sortIB(IB) = CartesianIndex(TT.getindices(IB.I, codomainind(B))) + keysB = sort!(vec(collect(nonzero_keys(B))); by=sortIB) + + iA = iB = 1 + @inbounds while iA <= length(keysA) && iB <= length(keysB) + IA = keysA[iA] + IB = keysB[iB] + IAc = CartesianIndex(TT.getindices(IA.I, domainind(A))) + IBc = CartesianIndex(TT.getindices(IB.I, codomainind(B))) + if IAc == IBc + Ic = IAc + jA = iA + while jA < length(keysA) + if CartesianIndex(TT.getindices(keysA[jA + 1].I, domainind(A))) == Ic + jA += 1 + else + break + end end - end - # handle `β` - if !did_mul - vC = get(C, CartesianIndex(I..., J...), nothing) - if !isnothing(vC) - C[I..., J...] = scale!!(vC, β) + jB = iB + while jB < length(keysB) + if CartesianIndex(TT.getindices(keysB[jB + 1].I, codomainind(B))) == Ic + jB += 1 + else + break + end + end + rA = iA:jA + rB = iB:jB + if length(rA) < length(rB) + for kB in rB + IB = keysB[kB] + IBo = CartesianIndex(TT.getindices(IB.I, domainind(B))) + vB = B[IB] + for kA in rA + IA = keysA[kA] + IAo = CartesianIndex(TT.getindices(IA.I, codomainind(A))) + IABo = CartesianIndex(IAo, IBo) + IC = CartesianIndex(TT.getindices(IABo.I, allind(C))) + vA = A[IA] + increasemulindex!(C, vA, vB, α, One(), IC) + end + end + else + for kA in rA + IA = keysA[kA] + IAo = CartesianIndex(TT.getindices(IA.I, codomainind(A))) + vA = A[IA] + for kB in rB + IB = keysB[kB] + IBo = CartesianIndex(TT.getindices(IB.I, domainind(B))) + vB = B[IB] + IABo = CartesianIndex(IAo, IBo) + IC = CartesianIndex(TT.getindices(IABo.I, allind(C))) + increasemulindex!(C, vA, vB, α, One(), IC) + end + end end + iA = jA + 1 + iB = jB + 1 + elseif IAc < IBc + iA += 1 + else + iB += 1 end end return C end +@inline function increasemulindex!( + C::AbstractBlockTensorMap, + A::AbstractTensorMap, + B::AbstractTensorMap, + α::Number, + β::Number, + I, +) + if haskey(C, I) + C[I] = _mul!!(C[I], A, B, α, β) + else + C[I] = _mul!!(nothing, A, B, α, β) + end +end + _mull!!(::Nothing, A, B, α::Number, β::Number) = scale!!(A * B, α) _mul!!(C, A, B, α::Number, β::Number) = add!!(C, A * B, α, β) const _TM_CAN_MUL = Union{TensorMap,AdjointTensorMap{<:TensorMap}} diff --git a/src/tensors/blocktensor.jl b/src/tensors/blocktensor.jl index 17f63b9..c86524e 100644 --- a/src/tensors/blocktensor.jl +++ b/src/tensors/blocktensor.jl @@ -231,7 +231,5 @@ end # Utility # ------- -Base.haskey(t::BlockTensorMap, I::CartesianIndex) = haskey(t.data, I) -function Base.haskey(t::BlockTensorMap, i::Int) - return haskey(t.data, CartesianIndices(t)[i]) -end +Base.haskey(t::BlockTensorMap, I::CartesianIndex) = checkbounds(Bool, t.data, I) +Base.haskey(t::BlockTensorMap, i::Int) = checkbounds(Bool, t.data, i)