Skip to content

Commit

Permalink
Soften type restrictions on tsvd kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jun 28, 2024
1 parent 8b41fe3 commit 6f18e40
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,7 @@ function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
return isposdef!(permute(t, (p₁, p₂); copy=true))
end

function tsvd(t::AbstractTensorMap; trunc::TruncationScheme=NoTruncation(),
p::Real=2, alg::Union{SVD,SDD}=SDD())
return tsvd!(copy(t); trunc=trunc, p=p, alg=alg)
end
tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...)
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
return leftorth!(copy(t); alg=alg, kwargs...)
end
Expand Down Expand Up @@ -413,19 +410,17 @@ end
#------------------------------#
# Singular value decomposition #
#------------------------------#
function tsvd!(t::AdjointTensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg)
return adjoint(vt), adjoint(s), adjoint(u), err
end
function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
return _tsvd!(t, alg, trunc, p)
end

function tsvd!(t::TensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
#early return
# implementation dispatches on algorithm
function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2)
# early return
if isempty(blocksectors(t))
truncerr = zero(real(scalartype(t)))
return _empty_svdtensors(t)..., truncerr
Expand All @@ -440,6 +435,8 @@ function tsvd!(t::TensorMap;
else
truncerr = abs(zero(scalartype(t)))
W = S(dims)
# TODO: do we really want this behaviour? this changes the arrows of the S legs, but
# only sometimes?
if length(domain(t)) == 1 && domain(t)[1] W
W = domain(t)[1]
elseif length(codomain(t)) == 1 && codomain(t)[1] W
Expand Down

0 comments on commit 6f18e40

Please sign in to comment.