Skip to content

Commit

Permalink
update implementations & add multi-threading eigh and interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
qyli committed Feb 7, 2024
1 parent c9a8ef7 commit a4de69a
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 96 deletions.
7 changes: 7 additions & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,17 @@ include("planar/planaroperations.jl")
# deprecations: to be removed in version 1.0 or sooner
include("auxiliary/deprecate.jl")

include("multithreading.jl")
# Extensions
# ----------
function __init__()
@require_extensions

global nthreads_mul = Threads.nthreads()
global nthreads_eigh = Threads.nthreads()
global nthreads_svd = Threads.nthreads()
global nthreads_add = Threads.nthreads()

end

end
33 changes: 33 additions & 0 deletions src/multithreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# global variables to control multi-threading behaviors
global nthreads_mul::Int64
global nthreads_eigh::Int64
global nthreads_svd::Int64
global nthreads_add::Int64

function set_num_threads_mul(n::Int64)
@assert 1 n Threads.nthreads()
global nthreads_mul = n
return nothing
end
get_num_threads_mul() = nthreads_mul

function set_num_threads_add(n::Int64)
@assert 1 n Threads.nthreads()
global nthreads_add = n
return nothing
end
get_num_threads_add() = nthreads_add

function set_num_threads_svd(n::Int64)
@assert 1 n Threads.nthreads()
global nthreads_svd = n
return nothing
end
get_num_threads_svd() = nthreads_svd

function set_num_threads_eigh(n::Int64)
@assert 1 n Threads.nthreads()
global nthreads_eigh = n
return nothing
end
get_num_threads_eigh() = nthreads_eigh
117 changes: 76 additions & 41 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,15 +438,16 @@ end

# helper functions

function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
numthreads::Int=Threads.nthreads())
function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!)
I = sectortype(t)
A = storagetype(t)
Udata = SectorDict{I,A}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
if numthreads == 1 || length(blocksectors(t)) == 1

num_threads = get_num_threads_svd()
if num_threads == 1 || length(blocksectors(t)) == 1
local Σdata
for (c, b) in blocks(t)
U, Σ, V = MatrixAlgebra.svd!(b, alg)
Expand All @@ -459,56 +460,46 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
end
dims[c] = length(Σ)
end
elseif numthreads == -1
tasks = map(blocksectors(t)) do c
Threads.@spawn MatrixAlgebra.svd!(blocks(t)[c], alg)
end
for (c, task) in zip(blocksectors(t), tasks)
U, Σ, V = fetch(task)
Udata[c] = U
Vdata[c] = V
if @isdefined Σdata
Σdata[c] = Σ
else
Σdata = SectorDict(c => Σ)
end
dims[c] = length(Σ)
end
else
Σdata = SectorDict{I,Vector{real(scalartype(t))}}()

# sort sectors by size
# try to sort sectors by size
lsc = blocksectors(t)
lsD3 = map(lsc) do c
# O(D1^2D2) or O(D1D2^2)
return min(size(block(t, c), 1)^2 * size(block(t, c), 2),
size(block(t, c), 1) * size(block(t, c), 2)^2)
if isa(lsc, AbstractVector)
lsD3 = lsD3 = map(lsc) do c
# O(D1^2D2) or O(D1D2^2)
return min(size(block(t, c), 1)^2 * size(block(t, c), 2),
size(block(t, c), 1) * size(block(t, c), 2)^2)
end
lsc = lsc[sortperm(lsD3; rev=true)]
end
lsc = lsc[sortperm(lsD3; rev=true)]

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in vcat(lsc, fill(nothing, numthreads))
for c in lsc
put!(ch, c)
end
end

# consumers
Lock = Threads.SpinLock()
tasks = map(1:numthreads) do _
task = Threads.@spawn while true
c = take!(ch)
isnothing(c) && break
Lock = Threads.ReentrantLock()
tasks = map(1:num_threads) do _
task = Threads.@spawn for c in ch
U, Σ, V = MatrixAlgebra.svd!(blocks(t)[c], alg)

# note inserting keys to dict is not thread safe
lock(Lock)
Udata[c] = U
Vdata[c] = V
Σdata[c] = Σ
dims[c] = length(Σ)
unlock(Lock)
try
Udata[c] = U
Vdata[c] = V
Σdata[c] = Σ
dims[c] = length(Σ)
catch
rethrow()
finally
unlock(Lock)
end
end
return errormonitor(task)
end
Expand Down Expand Up @@ -577,13 +568,57 @@ function eigh!(t::TensorMap)
Ddata = SectorDict{I,Ar}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eigh!(b)
d = length(values)
Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values))
Vdata[c] = vectors
dims[c] = d

num_threads = get_num_threads_eigh()
lsc = blocksectors(t)
if num_threads == 1 || length(lsc) == 1
for c in lsc
values, vectors = MatrixAlgebra.eigh!(block(t, c))
d = length(values)
Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values))
Vdata[c] = vectors
dims[c] = d
end
else
# try to sort sectors by size
if isa(lsc, AbstractVector)
lsc = sort(lsc; by=c -> size(block(t, c), 1), rev=true)
end

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in lsc
put!(ch, c)
end
end

# consumers
Lock = Threads.ReentrantLock()
tasks = map(1:num_threads) do _
task = Threads.@spawn for c in ch
values, vectors = MatrixAlgebra.eigh!(block(t, c))
d = length(values)
values = copyto!(similar(values, (d, d)), Diagonal(values))

lock(Lock)
try
Ddata[c] = values
Vdata[c] = vectors
dims[c] = d
catch
rethrow()
finally
unlock(Lock)
end
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])
end

if length(domain(t)) == 1
W = domain(t)[1]
else
Expand Down
50 changes: 18 additions & 32 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,14 @@ end
# TensorMap multiplication
function LinearAlgebra.mul!(tC::AbstractTensorMap,
tA::AbstractTensorMap,
tB::AbstractTensorMap, α=true, β=false;
numthreads::Int64=Threads.nthreads())
tB::AbstractTensorMap, α=true, β=false)
if !(codomain(tC) == codomain(tA) && domain(tC) == domain(tB) &&
domain(tA) == codomain(tB))
throw(SpaceMismatch("$(space(tC))$(space(tA)) * $(space(tB))"))
end

if numthreads == 1 || length(blocksectors(tC)) == 1
num_threads = get_num_threads_mul()
if num_threads == 1 || length(blocksectors(tC)) == 1
for c in blocksectors(tC)
if hasblock(tA, c) # then also tB should have such a block
A = block(tA, c)
Expand All @@ -258,50 +258,36 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
end
end

elseif numthreads == -1
Threads.@sync for c in blocksectors(tC)
if hasblock(tA, c)
Threads.@spawn mul!(StridedView(block(tC, c)),
StridedView(block(tA, c)),
StridedView(block(tB, c)),
α, β)
elseif β != one(β)
Threads.@spawn rmul!(block(tC, c), β)
end
end

else

# sort sectors by size
lsc = blocksectors(tC)
lsD3 = map(lsc) do c
if hasblock(tA, c)
return size(block(tA, c), 1) * size(block(tA, c), 2) *
size(block(tA, c), 2)
else
return size(block(tC, c), 1) * size(block(tC, c), 2)
# try to sort sectors by size
if isa(lsc, AbstractVector)
lsD3 = map(lsc) do c
if hasblock(tA, c)
return size(block(tA, c), 1) * size(block(tA, c), 2) *
size(block(tA, c), 2)
else
return size(block(tC, c), 1) * size(block(tC, c), 2)
end
end
end
lsc = lsc[sortperm(lsD3; rev=true)]

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in vcat(lsc, fill(nothing, numthreads))
for c in lsc
put!(ch, c)
end
end

# consumers
tasks = map(1:numthreads) do _
task = Threads.@spawn while true
c = take!(ch)
isnothing(c) && break

tasks = map(1:num_threads) do _
task = Threads.@spawn for c in ch
if hasblock(tA, c)
mul!(StridedView(block(tC, c)),
StridedView(block(tA, c)),
StridedView(block(tB, c)),
mul!(block(tC, c),
block(tA, c),
block(tB, c),
α, β)
elseif β != one(β)
rmul!(block(tC, c), β)
Expand Down
41 changes: 18 additions & 23 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,31 @@ function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap,
return VectorInterface.add!(scale!(similar(ty, T), ty, β), tx, α)
end
function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number, β::Number;
numthreads::Int64=1)
α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
if numthreads == 1 || length(blocksectors) == 1
for c in blocksectors(tx)
num_threads = get_num_threads_add()
lsc = blocksectors(tx)
if num_threads == 1 || length(lsc) == 1
for c in lsc
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
elseif numthreads == -1
Threads.@sync for c in blocksectors(tx)
Threads.@spawn VectorInterface.add!(block(ty, c), block(tx, c), α, β)
else
# try to sort sectors by size
if isa(lsc, AbstractVector)
# warning: using `sort!` here is not safe. I found it will lead to a "key ... not found" error when show tx again
lsc = sort(lsc; by=c -> prod(size(block(tx, c))), rev=true)
end
else
# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in vcat(blocksectors(tx), fill(nothing, numthreads))
put!(ch, c)
end
end
# consumers
tasks = map(1:numthreads) do _
task = Threads.@spawn while true
c = take!(ch)

idx = Threads.Atomic{Int64}(1)
Threads.@sync for _ in 1:num_threads
Threads.@spawn while true
i = Threads.atomic_add!(idx, 1)
i > length(lsc) && break

c = lsc[i]
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])
end
return ty
end
Expand Down

0 comments on commit a4de69a

Please sign in to comment.