Skip to content

Commit

Permalink
compatible with tensors over ElementarySpace
Browse files Browse the repository at this point in the history
  • Loading branch information
qyli committed Jan 12, 2024
1 parent df5b54f commit f23c52f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
13 changes: 6 additions & 7 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,8 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
Udata = SectorDict{I,A}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
local Σdata
if numthreads == 1
if numthreads == 1 || length(blocksectors(t)) == 1
local Σdata
for (c, b) in blocks(t)
U, Σ, V = MatrixAlgebra.svd!(b, alg)
Udata[c] = U
Expand Down Expand Up @@ -481,8 +481,8 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
lsc = blocksectors(t)
lsD3 = map(lsc) do c
# O(D1^2D2) or O(D1D2^2)
return min(size(blocks(t)[c])[1]^2 * size(blocks(t)[c])[2],
size(blocks(t)[c])[1] * size(blocks(t)[c])[2]^2)
return min(size(block(t, c), 1)^2 * size(block(t, c), 2),
size(block(t, c), 1) * size(block(t, c), 2)^2)
end
lsc = lsc[sortperm(lsD3; rev=true)]

Expand All @@ -509,13 +509,12 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
Σdata[c] = Σ
dims[c] = length(Σ)
unlock(Lock)
end
errormonitor(task)
end
return errormonitor(task)
end

wait.(tasks)
wait(taskref[])

end
return Udata, Σdata, Vdata, dims
end
Expand Down
8 changes: 4 additions & 4 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
throw(SpaceMismatch("$(space(tC))$(space(tA)) * $(space(tB))"))
end

if numthreads == 1
if numthreads == 1 || length(blocksectors(tC)) == 1
for c in blocksectors(tC)
if hasblock(tA, c) # then also tB should have such a block
A = block(tA, c)
Expand Down Expand Up @@ -276,10 +276,10 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
lsc = blocksectors(tC)
lsD3 = map(lsc) do c
if hasblock(tA, c)
return size(blocks(tA)[c], 1) * size(blocks(tA)[c], 2) *
size(blocks(tB)[c], 2)
return size(block(tA, c), 1) * size(block(tA, c), 2) *
size(block(tA, c), 2)
else
return size(blocks(tC)[c], 1) * size(blocks(tC)[c], 2)
return size(block(tC, c), 1) * size(block(tC, c), 2)
end
end
lsc = lsc[sortperm(lsD3; rev=true)]
Expand Down
4 changes: 2 additions & 2 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number, β::Number;
numthreads::Int64=1)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
if numthreads == 1
if numthreads == 1 || length(blocksectors) == 1
for c in blocksectors(tx)
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
Expand All @@ -83,7 +83,7 @@ function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
c = take!(ch)
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
return errormonitor(tast)
return errormonitor(task)
end

wait.(tasks)
Expand Down
10 changes: 10 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"

0 comments on commit f23c52f

Please sign in to comment.