Skip to content

Commit

Permalink
more mul! things
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Nov 19, 2024
1 parent 86c8036 commit 348161c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 29 deletions.
106 changes: 81 additions & 25 deletions src/linalg/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,39 +66,95 @@ function LinearAlgebra.mul!(
compose(space(A), space(B)) == space(C) ||
throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))"))

Caxes = Base.Fix1(axes, C)
Aaxes = Base.Fix1(axes, A)

for I in Iterators.product(map(Caxes, codomainind(C))...),
J in Iterators.product(map(Caxes, domainind(C))...)

did_mul = false
for K in Iterators.product(map(Aaxes, domainind(A))...)
vA = get(A, CartesianIndex(I..., K...), nothing)
isnothing(vA) && continue
vB = get(B, CartesianIndex(K..., J...), nothing)
isnothing(vB) && continue

vC = get(C, CartesianIndex(I..., J...), nothing)
if did_mul
C[I..., J...] = _mul!!(vC, vA, vB, α, One())
else
C[I..., J...] = _mul!!(vC, vA, vB, α, β)
did_mul = true
scale!(C, β)

sortIA(IA) = CartesianIndex(TT.getindices(IA.I, domainind(A)))
keysA = sort!(vec(collect(nonzero_keys(A))); by=sortIA)
sortIB(IB) = CartesianIndex(TT.getindices(IB.I, codomainind(B)))
keysB = sort!(vec(collect(nonzero_keys(B))); by=sortIB)

iA = iB = 1
@inbounds while iA <= length(keysA) && iB <= length(keysB)
IA = keysA[iA]
IB = keysB[iB]
IAc = CartesianIndex(TT.getindices(IA.I, domainind(A)))
IBc = CartesianIndex(TT.getindices(IB.I, codomainind(B)))
if IAc == IBc
Ic = IAc
jA = iA
while jA < length(keysA)
if CartesianIndex(TT.getindices(keysA[jA + 1].I, domainind(A))) == Ic
jA += 1
else
break
end
end
end
# handle `β`
if !did_mul
vC = get(C, CartesianIndex(I..., J...), nothing)
if !isnothing(vC)
C[I..., J...] = scale!!(vC, β)
jB = iB
while jB < length(keysB)
if CartesianIndex(TT.getindices(keysB[jB + 1].I, codomainind(B))) == Ic
jB += 1
else
break
end
end
rA = iA:jA
rB = iB:jB
if length(rA) < length(rB)
for kB in rB
IB = keysB[kB]
IBo = CartesianIndex(TT.getindices(IB.I, domainind(B)))
vB = B[IB]
for kA in rA
IA = keysA[kA]
IAo = CartesianIndex(TT.getindices(IA.I, codomainind(A)))
IABo = CartesianIndex(IAo, IBo)
IC = CartesianIndex(TT.getindices(IABo.I, allind(C)))
vA = A[IA]
increasemulindex!(C, vA, vB, α, One(), IC)
end
end
else
for kA in rA
IA = keysA[kA]
IAo = CartesianIndex(TT.getindices(IA.I, codomainind(A)))
vA = A[IA]
for kB in rB
IB = keysB[kB]
IBo = CartesianIndex(TT.getindices(IB.I, domainind(B)))
vB = B[IB]
IABo = CartesianIndex(IAo, IBo)
IC = CartesianIndex(TT.getindices(IABo.I, allind(C)))
increasemulindex!(C, vA, vB, α, One(), IC)
end
end
end
iA = jA + 1
iB = jB + 1
elseif IAc < IBc
iA += 1
else
iB += 1
end
end

return C
end

@inline function increasemulindex!(
C::AbstractBlockTensorMap,
A::AbstractTensorMap,
B::AbstractTensorMap,
α::Number,
β::Number,
I,
)
if haskey(C, I)
C[I] = _mul!!(C[I], A, B, α, β)
else
C[I] = _mul!!(nothing, A, B, α, β)
end
end

_mull!!(::Nothing, A, B, α::Number, β::Number) = scale!!(A * B, α)
_mul!!(C, A, B, α::Number, β::Number) = add!!(C, A * B, α, β)
const _TM_CAN_MUL = Union{TensorMap,AdjointTensorMap{<:TensorMap}}
Expand Down
6 changes: 2 additions & 4 deletions src/tensors/blocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,5 @@ end

# Utility
# -------
Base.haskey(t::BlockTensorMap, I::CartesianIndex) = haskey(t.data, I)
function Base.haskey(t::BlockTensorMap, i::Int)
return haskey(t.data, CartesianIndices(t)[i])
end
Base.haskey(t::BlockTensorMap, I::CartesianIndex) = checkbounds(Bool, t.data, I)
Base.haskey(t::BlockTensorMap, i::Int) = checkbounds(Bool, t.data, i)

0 comments on commit 348161c

Please sign in to comment.