From d41cbbc6ed980baedb22befbf84a42a07f658dda Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 30 Dec 2024 15:19:02 +0100 Subject: [PATCH] update TensorKit compat --- Project.toml | 2 +- src/linalg/factorizations.jl | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f26ac20..d4f2344 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ LinearAlgebra = "1" Random = "1" Strided = "2" SafeTestsets = "0.1" -TensorKit = "0.13.2" +TensorKit = "0.13.2, 0.14" TensorOperations = "5" Test = "1" TestExtras = "0.2" diff --git a/src/linalg/factorizations.jl b/src/linalg/factorizations.jl index d68742a..8281e5a 100644 --- a/src/linalg/factorizations.jl +++ b/src/linalg/factorizations.jl @@ -213,6 +213,27 @@ function TK.tsvd!(t::SparseBlockTensorMap; kwargs...) return tsvd!(BlockTensorMap(t); kwargs...) end +function TK._tsvd!( + t::BlockTensorMap, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2 +) + # early return + if isempty(blocksectors(t)) + truncerr = zero(real(scalartype(t))) + return TK._empty_svdtensors(t)..., truncerr + end + + # compute SVD factorization for each block + S = spacetype(t) + SVDdata, dims = TK._compute_svddata!(t, alg) + Σdata = SectorDict(c => Σ for (c, (U, Σ, V)) in SVDdata) + truncdim = TK._compute_truncdim(Σdata, trunc, p) + truncerr = TK._compute_truncerr(Σdata, truncdim, p) + + # construct output tensors + U, Σ, V⁺ = TK._create_svdtensors(t, SVDdata, truncdim) + return U, Σ, V⁺, truncerr +end + function TK._compute_svddata!(t::AbstractBlockTensorMap, alg::Union{SVD,SDD}) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) I = sectortype(t) @@ -225,7 +246,6 @@ function TK._compute_svddata!(t::AbstractBlockTensorMap, alg::Union{SVD,SDD}) SVDdata = SectorDict(generator) return SVDdata, dims end - function TK._create_svdtensors(t::AbstractBlockTensorMap, SVDdata, dims) S = spacetype(t) W = S(dims) @@ -241,3 +261,16 @@ function TK._create_svdtensors(t::AbstractBlockTensorMap, SVDdata, dims) end return U, Σ, V⁺ end + +function TK._empty_svdtensors(t::AbstractBlockTensorMap) + T = scalartype(t) + S = spacetype(t) + I = sectortype(t) + dims = SectorDict{I,Int}() + W = S(dims) + + U = similar(t, codomain(t) ← W) + Σ = similar(t, real(T), W ← W) + V⁺ = similar(t, W ← domain(t)) + return U, Σ, V⁺ +end