add multi-threading support for mul!, add! and tsvd! #100

Open · wants to merge 4 commits into base: master
.gitignore: 3 additions, 1 deletion
@@ -5,4 +5,6 @@
*-old
__pycache__
.ipynb*
Manifest.toml
.vscode
dev
src/TensorKit.jl: 7 additions, 0 deletions
@@ -205,10 +205,17 @@ include("planar/planaroperations.jl")
# deprecations: to be removed in version 1.0 or sooner
include("auxiliary/deprecate.jl")

include("multithreading.jl")
# Extensions
# ----------
function __init__()
@require_extensions

global nthreads_mul = Threads.nthreads()
global nthreads_eigh = Threads.nthreads()
global nthreads_svd = Threads.nthreads()
global nthreads_add = Threads.nthreads()

end

end
src/multithreading.jl: 33 additions, 0 deletions
@@ -0,0 +1,33 @@
# global variables to control multi-threading behaviors
global nthreads_mul::Int64
global nthreads_eigh::Int64
global nthreads_svd::Int64
global nthreads_add::Int64

function set_num_threads_mul(n::Int64)
@assert 1 ≤ n ≤ Threads.nthreads()
global nthreads_mul = n
return nothing
end
get_num_threads_mul() = nthreads_mul

function set_num_threads_add(n::Int64)
@assert 1 ≤ n ≤ Threads.nthreads()
global nthreads_add = n
return nothing
end
get_num_threads_add() = nthreads_add

function set_num_threads_svd(n::Int64)
@assert 1 ≤ n ≤ Threads.nthreads()
global nthreads_svd = n
return nothing
end
get_num_threads_svd() = nthreads_svd

function set_num_threads_eigh(n::Int64)
@assert 1 ≤ n ≤ Threads.nthreads()
global nthreads_eigh = n
return nothing
end
get_num_threads_eigh() = nthreads_eigh
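
A quick usage sketch (a hypothetical REPL session, assuming Julia was started with 8 threads, e.g. `julia -t 8`; the setters are not exported, so they are reached through the module prefix):

    julia> using TensorKit

    julia> TensorKit.set_num_threads_svd(4)   # cap the SVD stage at 4 threads

    julia> TensorKit.get_num_threads_svd()
    4

    julia> TensorKit.get_num_threads_mul()    # defaults to Threads.nthreads(), set in __init__
    8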
src/tensors/factorizations.jl: 110 additions, 16 deletions
@@ -445,17 +445,67 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
Udata = SectorDict{I,A}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()

num_threads = get_num_threads_svd()
if num_threads == 1 || length(blocksectors(t)) == 1
local Σdata
for (c, b) in blocks(t)
U, Σ, V = MatrixAlgebra.svd!(b, alg)
Udata[c] = U
Vdata[c] = V
if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction
Σdata[c] = Σ
else
Σdata = SectorDict(c => Σ)
end
dims[c] = length(Σ)
end
else
Σdata = SectorDict{I,Vector{real(scalartype(t))}}()

# try to sort sectors by size
lsc = blocksectors(t)
if isa(lsc, AbstractVector)
lsD3 = map(lsc) do c
# O(D1^2D2) or O(D1D2^2)
return min(size(block(t, c), 1)^2 * size(block(t, c), 2),
size(block(t, c), 1) * size(block(t, c), 2)^2)
end
lsc = lsc[sortperm(lsD3; rev=true)]
end

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in lsc
put!(ch, c)
end
end

# consumers
Lock = Threads.ReentrantLock()
tasks = map(1:num_threads) do _
task = Threads.@spawn for c in ch
U, Σ, V = MatrixAlgebra.svd!(blocks(t)[c], alg)

# note: inserting keys into a Dict is not thread-safe
[Collaborator] Is there a way to avoid the lock? In principle it should be possible to "pre-allocate" the dictionaries, as in this case we know all keys in advance, and also that every entry will be visited just once.

[Author] This is a good idea for avoiding the lock. I implemented it as the following code:

    # preallocate
    for c in lsc
        sz = size(block(t, c))
        dims[c] = min(sz[1], sz[2])
        Udata[c] = similar(block(t, c), sz[1], dims[c])
        Vdata[c] = similar(block(t, c), dims[c], sz[2])
        Σdata[c] = similar(block(t, c), dims[c])
    end

    # producer
    taskref = Ref{Task}()
    ch = Channel(; taskref=taskref, spawn=true) do ch
        for c in lsc
            put!(ch, c)
        end
    end

    # consumers
    tasks = map(1:ntasks) do _
        task = Threads.@spawn for c in ch
            U, Σ, V = MatrixAlgebra.svd!(block(t, c), alg)
            copyto!(Udata[c], U)
            copyto!(Vdata[c], V)
            copyto!(Σdata[c], Σ)
        end
        return errormonitor(task)
    end

However, I found that it does not change the time cost, but it does lead to some additional memory allocations due to the temporary local variables U, Σ, V. I also tried

      Udata[c][:], Σdata[c][:], Vdata[c][:] = MatrixAlgebra.svd!(block(t, c), alg)

Unfortunately, this naive attempt does not solve the problem, because the compiler does not optimize it into an in-place assignment; see the following example:

    julia> Meta.@lower U[:], S[:], V[:] = svd!(A)
    :($(Expr(:thunk, CodeInfo(
        @ none within `top-level scope`
    1 ─ %1 = svd!(A)
    │   %2 = Base.indexed_iterate(%1, 1)
    │   %3 = Core.getfield(%2, 1)
    │        #s49 = Core.getfield(%2, 2)
    │   %5 = Base.indexed_iterate(%1, 2, #s49)
    │   %6 = Core.getfield(%5, 1)
    │        #s49 = Core.getfield(%5, 2)
    │   %8 = Base.indexed_iterate(%1, 3, #s49)
    │   %9 = Core.getfield(%8, 1)
    │        Base.setindex!(U, %3, :)
    │        Base.setindex!(S, %6, :)
    │        Base.setindex!(V, %9, :)
    └──      return %1
    ))))

In conclusion, I think the original implementation is better. If you have any idea for avoiding the additional allocations, please tell me; I would be happy to learn it.
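
A further lock-free variant is possible in principle (a sketch only, not part of this PR; it reuses the surrounding definitions of Udata, Σdata, Vdata, dims, blocksectors, block, and MatrixAlgebra.svd!): collect each block's factorization into a plain vector slot, then fill the SectorDicts serially, so no lock is taken and the factors are stored without copying:

    # each slot is written by exactly one task, so no lock is needed
    lsc = collect(blocksectors(t))
    results = Vector{Any}(undef, length(lsc))
    Threads.@threads for i in eachindex(lsc)
        results[i] = MatrixAlgebra.svd!(block(t, lsc[i]), alg)
    end
    # single-threaded dict construction afterwards
    for (i, c) in enumerate(lsc)
        U, Σ, V = results[i]
        Udata[c] = U
        Σdata[c] = Σ
        Vdata[c] = V
        dims[c] = length(Σ)
    end

Note that Threads.@threads partitions the index range statically, so the size-sorting heuristic used elsewhere in this PR has less effect here; the Vector{Any} slots also sidestep the type-inference issue that motivated the @isdefined construction above.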

lock(Lock)
try
Udata[c] = U
Vdata[c] = V
Σdata[c] = Σ
dims[c] = length(Σ)
finally
unlock(Lock)
end
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])
end
return Udata, Σdata, Vdata, dims
end
@@ -518,13 +568,57 @@ function eigh!(t::TensorMap)
Ddata = SectorDict{I,Ar}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()

num_threads = get_num_threads_eigh()
lsc = blocksectors(t)
if num_threads == 1 || length(lsc) == 1
for c in lsc
values, vectors = MatrixAlgebra.eigh!(block(t, c))
d = length(values)
Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values))
Vdata[c] = vectors
dims[c] = d
end
else
# try to sort sectors by size
if isa(lsc, AbstractVector)
lsc = sort(lsc; by=c -> size(block(t, c), 1), rev=true)
end

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in lsc
put!(ch, c)
end
end

# consumers
Lock = Threads.ReentrantLock()
tasks = map(1:num_threads) do _
task = Threads.@spawn for c in ch
values, vectors = MatrixAlgebra.eigh!(block(t, c))
d = length(values)
values = copyto!(similar(values, (d, d)), Diagonal(values))

lock(Lock)
try
Ddata[c] = values
Vdata[c] = vectors
dims[c] = d
finally
unlock(Lock)
end
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])
end

if length(domain(t)) == 1
W = domain(t)[1]
else
src/tensors/linalg.jl: 53 additions, 8 deletions
@@ -244,15 +244,60 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
domain(tA) == codomain(tB))
throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
end

num_threads = get_num_threads_mul()
if num_threads == 1 || length(blocksectors(tC)) == 1
for c in blocksectors(tC)
if hasblock(tA, c) # then also tB should have such a block
A = block(tA, c)
B = block(tB, c)
C = block(tC, c)
mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
elseif β != one(β)
rmul!(block(tC, c), β)
end
end

else
lsc = blocksectors(tC)
# try to sort sectors by estimated cost (largest first)
if isa(lsc, AbstractVector)
lsD3 = map(lsc) do c
if hasblock(tA, c)
# cost of C = A * B is O(m * k * n)
return size(block(tA, c), 1) * size(block(tA, c), 2) *
size(block(tB, c), 2)
else
return size(block(tC, c), 1) * size(block(tC, c), 2)
end
end
lsc = lsc[sortperm(lsD3; rev=true)]
end

# producer
taskref = Ref{Task}()
ch = Channel(; taskref=taskref, spawn=true) do ch
for c in lsc
put!(ch, c)
end
end

# consumers
tasks = map(1:num_threads) do _
task = Threads.@spawn for c in ch
if hasblock(tA, c)
mul!(block(tC, c),
block(tA, c),
block(tB, c),
α, β)
elseif β != one(β)
rmul!(block(tC, c), β)
end
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])
end
return tC
end
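
The producer/consumer scheme in mul! above has the same shape as in eigh! and _compute_svddata!. A minimal self-contained sketch of the pattern, with generic stand-in names and nothing TensorKit-specific:

    items = collect(1:10)   # stand-in for the sorted list of block sectors
    ntasks = 4

    # producer: a spawned task feeds work items into the channel and closes it when done
    taskref = Ref{Task}()
    ch = Channel(; taskref=taskref, spawn=true) do ch
        for item in items
            put!(ch, item)
        end
    end

    # consumers: iterating a Channel blocks until an item is available and
    # terminates once the channel is closed and drained
    tasks = map(1:ntasks) do _
        errormonitor(Threads.@spawn for item in ch
            println("processing item $item on thread $(Threads.threadid())")
        end)
    end

    wait.(tasks)     # all consumers finished
    wait(taskref[])  # producer finished

Because consumers pull from a shared channel, a task that finishes a small block immediately picks up the next one, which is why sorting the sectors from large to small improves load balance.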
src/tensors/vectorinterface.jl: 23 additions, 2 deletions
@@ -60,8 +60,29 @@ end
function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))"))
num_threads = get_num_threads_add()
lsc = blocksectors(tx)
if num_threads == 1 || length(lsc) == 1
for c in lsc
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
else
# try to sort sectors by size
if isa(lsc, AbstractVector)
# warning: using `sort!` here is not safe: it mutates the vector returned by blocksectors (apparently shared with the tensor's internal data) and leads to a "key ... not found" error when showing tx afterwards
lsc = sort(lsc; by=c -> prod(size(block(tx, c))), rev=true)
end

idx = Threads.Atomic{Int64}(1)
Threads.@sync for _ in 1:num_threads
Threads.@spawn while true
i = Threads.atomic_add!(idx, 1)
i > length(lsc) && break

c = lsc[i]
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
end
end
return ty
end
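
The Threads.Atomic counter above is a lighter-weight alternative to the Channel used in the other methods: each task repeatedly claims the next index until the list is exhausted. A generic sketch of the same pattern (foreach_dynamic is a hypothetical helper name, not part of this PR):

    # apply f to every item, handing out indices dynamically via an atomic counter
    function foreach_dynamic(f, items, ntasks::Int)
        idx = Threads.Atomic{Int}(1)
        @sync for _ in 1:ntasks
            Threads.@spawn while true
                i = Threads.atomic_add!(idx, 1)  # fetch-and-add: returns the value before the increment
                i > length(items) && break
                f(items[i])
            end
        end
        return nothing
    end

    # usage: square 100 numbers with 4 tasks; writes go to distinct indices, so no lock is needed
    out = zeros(Int, 100)
    foreach_dynamic(i -> (out[i] = i^2), eachindex(out), 4)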