Skip to content

Commit

Permalink
Soften type restrictions on tsvd kwargs (#134)
Browse files Browse the repository at this point in the history
* Soften type restrictions on tsvd kwargs

* Fix spaces in truncated svd to non-dual
  • Loading branch information
lkdvos authored Jul 2, 2024
1 parent 8b41fe3 commit 3275ffb
Showing 1 changed file with 12 additions and 23 deletions.
35 changes: 12 additions & 23 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,7 @@ function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
return isposdef!(permute(t, (p₁, p₂); copy=true))
end

function tsvd(t::AbstractTensorMap; trunc::TruncationScheme=NoTruncation(),
p::Real=2, alg::Union{SVD,SDD}=SDD())
return tsvd!(copy(t); trunc=trunc, p=p, alg=alg)
end
tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...)
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
return leftorth!(copy(t); alg=alg, kwargs...)
end
Expand Down Expand Up @@ -413,39 +410,31 @@ end
#------------------------------#
# Singular value decomposition #
#------------------------------#
function tsvd!(t::AdjointTensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg)
return adjoint(vt), adjoint(s), adjoint(u), err
end
function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
return _tsvd!(t, alg, trunc, p)
end

function tsvd!(t::TensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
#early return
# implementation dispatches on algorithm
function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2)
# early return
if isempty(blocksectors(t))
truncerr = zero(real(scalartype(t)))
return _empty_svdtensors(t)..., truncerr
end

S = spacetype(t)
Udata, Σdata, Vdata, dims = _compute_svddata!(t, alg)
if !isa(trunc, NoTruncation)
if trunc isa NoTruncation
truncerr = abs(zero(scalartype(t)))
else
Σdata, truncerr = _truncate!(Σdata, trunc, p)
Udata, Σdata, Vdata, dims = _implement_svdtruncation!(t, Udata, Σdata, Vdata, dims)
W = S(dims)
else
truncerr = abs(zero(scalartype(t)))
W = S(dims)
if length(domain(t)) == 1 && domain(t)[1] W
W = domain(t)[1]
elseif length(codomain(t)) == 1 && codomain(t)[1] W
W = codomain(t)[1]
end
end
W = S(dims)
return _create_svdtensors(t, Udata, Σdata, Vdata, W)..., truncerr
end

Expand Down

0 comments on commit 3275ffb

Please sign in to comment.