Skip to content

Commit

Permalink
add multi-threading support for mul!, add! and tsvd!
Browse files Browse the repository at this point in the history
  • Loading branch information
qyli committed Jan 11, 2024
1 parent c097f3e commit df5b54f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 23 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
*-old
__pycache__
.ipynb*
Manifest.toml
Manifest.toml
.vscode
dev
80 changes: 70 additions & 10 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,24 +438,84 @@ end

# helper functions

function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
numthreads::Int=Threads.nthreads())
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!)
I = sectortype(t)
A = storagetype(t)
Udata = SectorDict{I,A}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
local Σdata
for (c, b) in blocks(t)
U, Σ, V = MatrixAlgebra.svd!(b, alg)
Udata[c] = U
Vdata[c] = V
if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction
Σdata[c] = Σ
else
Σdata = SectorDict(c => Σ)
if numthreads == 1
for (c, b) in blocks(t)
U, Σ, V = MatrixAlgebra.svd!(b, alg)
Udata[c] = U
Vdata[c] = V
if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction
Σdata[c] = Σ
else
Σdata = SectorDict(c => Σ)
end
dims[c] = length(Σ)
end
elseif numthreads == -1
tasks = map(blocksectors(t)) do c
Threads.@spawn MatrixAlgebra.svd!(blocks(t)[c], alg)
end
for (c, task) in zip(blocksectors(t), tasks)
U, Σ, V = fetch(task)
Udata[c] = U
Vdata[c] = V
if @isdefined Σdata
Σdata[c] = Σ
else
Σdata = SectorDict(c => Σ)
end
dims[c] = length(Σ)
end
else
Σdata = SectorDict{I,Vector{real(scalartype(t))}}()

# sort sectors by size
lsc = blocksectors(t)
lsD3 = map(lsc) do c
# O(D1^2D2) or O(D1D2^2)
return min(size(blocks(t)[c])[1]^2 * size(blocks(t)[c])[2],
size(blocks(t)[c])[1] * size(blocks(t)[c])[2]^2)
end
dims[c] = length(Σ)
lsc = lsc[sortperm(lsD3; rev=true)]

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in vcat(lsc, fill(nothing, numthreads))
put!(ch, c)
end
end

# consumers
Lock = Threads.SpinLock()
tasks = map(1:numthreads) do _
task = Threads.@spawn while true
c = take!(ch)
isnothing(c) && break
U, Σ, V = MatrixAlgebra.svd!(blocks(t)[c], alg)

# note inserting keys to dict is not thread safe
lock(Lock)
Udata[c] = U
Vdata[c] = V
Σdata[c] = Σ
dims[c] = length(Σ)
unlock(Lock)
end
errormonitor(task)
end

wait.(tasks)
wait(taskref[])

end
return Udata, Σdata, Vdata, dims
end
Expand Down
77 changes: 68 additions & 9 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,20 +239,79 @@ end
# TensorMap multiplication
function LinearAlgebra.mul!(tC::AbstractTensorMap,
tA::AbstractTensorMap,
tB::AbstractTensorMap, α=true, β=false)
tB::AbstractTensorMap, α=true, β=false;
numthreads::Int64=Threads.nthreads())
if !(codomain(tC) == codomain(tA) && domain(tC) == domain(tB) &&
domain(tA) == codomain(tB))
throw(SpaceMismatch("$(space(tC))$(space(tA)) * $(space(tB))"))
end
for c in blocksectors(tC)
if hasblock(tA, c) # then also tB should have such a block
A = block(tA, c)
B = block(tB, c)
C = block(tC, c)
mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
elseif β != one(β)
rmul!(block(tC, c), β)

if numthreads == 1
for c in blocksectors(tC)
if hasblock(tA, c) # then also tB should have such a block
A = block(tA, c)
B = block(tB, c)
C = block(tC, c)
mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
elseif β != one(β)
rmul!(block(tC, c), β)
end
end

elseif numthreads == -1
Threads.@sync for c in blocksectors(tC)
if hasblock(tA, c)
Threads.@spawn mul!(StridedView(block(tC, c)),
StridedView(block(tA, c)),
StridedView(block(tB, c)),
α, β)
elseif β != one(β)
Threads.@spawn rmul!(block(tC, c), β)
end
end

else

# sort sectors by size
lsc = blocksectors(tC)
lsD3 = map(lsc) do c
if hasblock(tA, c)
return size(blocks(tA)[c], 1) * size(blocks(tA)[c], 2) *
size(blocks(tB)[c], 2)
else
return size(blocks(tC)[c], 1) * size(blocks(tC)[c], 2)
end
end
lsc = lsc[sortperm(lsD3; rev=true)]

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in vcat(lsc, fill(nothing, numthreads))
put!(ch, c)
end
end

# consumers
tasks = map(1:numthreads) do _
task = Threads.@spawn while true
c = take!(ch)
isnothing(c) && break

if hasblock(tA, c)
mul!(StridedView(block(tC, c)),
StridedView(block(tA, c)),
StridedView(block(tB, c)),
α, β)
elseif β != one(β)
rmul!(block(tC, c), β)
end
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])
end
return tC
end
Expand Down
32 changes: 29 additions & 3 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,36 @@ function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap,
return VectorInterface.add!(scale!(similar(ty, T), ty, β), tx, α)
end
function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number, β::Number)
α::Number, β::Number;
numthreads::Int64=1)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
for c in blocksectors(tx)
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
if numthreads == 1
for c in blocksectors(tx)
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
elseif numthreads == -1
Threads.@sync for c in blocksectors(tx)
Threads.@spawn VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
else
# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in vcat(blocksectors(tx), fill(nothing, numthreads))
put!(ch, c)
end
end
# consumers
tasks = map(1:numthreads) do _
task = Threads.@spawn while true
c = take!(ch)
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
return errormonitor(tast)
end

wait.(tasks)
wait(taskref[])
end
return ty
end
Expand Down

0 comments on commit df5b54f

Please sign in to comment.