From 3275ffbe04a51a96d83f871415d22a5a3c28c486 Mon Sep 17 00:00:00 2001 From: Lukas <37111893+lkdvos@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:22:19 +0200 Subject: [PATCH] Soften type restrictions on tsvd kwargs (#134) * Soften type restrictions on tsvd kwargs * Fix spaces in truncated svd to non-dual --- src/tensors/factorizations.jl | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 524e2760..6b325d48 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -246,10 +246,7 @@ function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) return isposdef!(permute(t, (p₁, p₂); copy=true)) end -function tsvd(t::AbstractTensorMap; trunc::TruncationScheme=NoTruncation(), - p::Real=2, alg::Union{SVD,SDD}=SDD()) - return tsvd!(copy(t); trunc=trunc, p=p, alg=alg) -end +tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...) function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...) return leftorth!(copy(t); alg=alg, kwargs...) end @@ -413,19 +410,17 @@ end #------------------------------# # Singular value decomposition # #------------------------------# -function tsvd!(t::AdjointTensorMap; - trunc::TruncationScheme=NoTruncation(), - p::Real=2, - alg::Union{SVD,SDD}=SDD()) +function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) return adjoint(vt), adjoint(s), adjoint(u), err end +function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) + return _tsvd!(t, alg, trunc, p) +end -function tsvd!(t::TensorMap; - trunc::TruncationScheme=NoTruncation(), - p::Real=2, - alg::Union{SVD,SDD}=SDD()) - #early return +# implementation dispatches on algorithm +function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2) + # early return if isempty(blocksectors(t)) truncerr = zero(real(scalartype(t))) return _empty_svdtensors(t)..., truncerr @@ -433,19 +428,13 @@ function tsvd!(t::TensorMap; S = spacetype(t) Udata, Σdata, Vdata, dims = _compute_svddata!(t, alg) - if !isa(trunc, NoTruncation) + if trunc isa NoTruncation + truncerr = abs(zero(scalartype(t))) + else Σdata, truncerr = _truncate!(Σdata, trunc, p) Udata, Σdata, Vdata, dims = _implement_svdtruncation!(t, Udata, Σdata, Vdata, dims) - W = S(dims) - else - truncerr = abs(zero(scalartype(t))) - W = S(dims) - if length(domain(t)) == 1 && domain(t)[1] ≅ W - W = domain(t)[1] - elseif length(codomain(t)) == 1 && codomain(t)[1] ≅ W - W = codomain(t)[1] - end end + W = S(dims) return _create_svdtensors(t, Udata, Σdata, Vdata, W)..., truncerr end