Skip to content

Commit

Permalink
Change default SVD algorithm to SVD
Browse files Browse the repository at this point in the history
- Use stable algorithm as the default.
  • Loading branch information
lkdvos committed Feb 13, 2024
1 parent 7391fe0 commit 8b0aed3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/TensorKitManifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export Grassmann, Stiefel, Unitary
export inner, retract, transport, transport!

using TensorKit
using TensorKit: SVD

# Every submodule -- Grassmann, Stiefel, and Unitary -- implements their own methods for
# these. The signatures should be
Expand Down Expand Up @@ -61,7 +62,7 @@ end

struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm
end
function projectisometric!(W::AbstractTensorMap; alg=Polar())
function projectisometric!(W::AbstractTensorMap; alg=TensorKit.SVD())
if alg isa TensorKit.Polar || alg isa TensorKit.SDD
foreach(blocks(W)) do (c, b)
return _polarsdd!(b)
Expand Down
12 changes: 6 additions & 6 deletions src/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ function inner(W::AbstractTensorMap, Δ₁::GrassmannTangent, Δ₂::GrassmannTa
end

"""
retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = nothing)
retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = TensorKit.SVD())
Retract isometry `W == base(Δ)` within the Grassmann manifold using tangent vector `Δ.Z`.
If the singular value decomposition of `Z` is given by `U * S * V`, then the resulting
Expand All @@ -169,28 +169,28 @@ while the local tangent vector along the retraction curve is
`Z′ = - W * V' * sin(α*S) * S * V + U * cos(α * S) * S * V'`.
"""
function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=nothing)
function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=SVD())
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
U, S, V = Δ.U, Δ.S, Δ.V
WVd = W * V'
sSV, cSV = _sincosSV(α, S, V) # sin(S)*V, cos(S)*V
W′ = projectisometric!(WVd * cSV + U * sSV)
W′ = projectisometric!(WVd * cSV + U * sSV; alg)
sSSV = _lmul!(S, sSV) # sin(S)*S*V
cSSV = _lmul!(S, cSV) # cos(S)*S*V
Z′ = projectcomplement!(-WVd * sSSV + U * cSSV, W′)
return W′, GrassmannTangent(W′, Z′)
end

"""
Grassmann.invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = nothing)
Grassmann.invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = SVD())
Return the Grassmann tangent `Z` and unitary `Y` such that `retract(Wold, Z, 1) * Y ≈ Wnew`.
This is done by solving the equation `Wold * V' * cos(S) * V + U * sin(S) * V = Wnew * Y'`
for the isometries `U`, `V`, and `Y`, and the diagonal matrix `S`, and returning
`Z = U * S * V` and `Y`.
"""
function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothing)
function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=SVD())
space(Wold) == space(Wnew) || throw(SectorMismatch())
WodWn = Wold' * Wnew # V' * cos(S) * V * Y
Wneworth = Wnew - Wold * WodWn
Expand All @@ -199,7 +199,7 @@ function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothin
# acos always returns a complex TensorMap. We cast back to real if possible.
S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx
UsS = Wneworth * VY' # U * sin(S) # should be in polar decomposition form
U = projectisometric!(UsS; alg=Polar())
U = projectisometric!(UsS; alg=SVD())
Y = Vd * VY
V = Vd'
Z = Grassmann.GrassmannTangent(Wold, U * S * V)
Expand Down
4 changes: 2 additions & 2 deletions src/unitary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ end
project(X, W; metric=:euclidean) = project!(copy(X), W; metric=:euclidean)

# geodesic retraction, coincides with Stiefel retraction (which is not geodesic for p < n)
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=nothing)
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=SVD())
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
E = exp* Δ.A)
W′ = projectisometric!(W * E; alg=SDD())
W′ = projectisometric!(W * E; alg)
A′ = Δ.A
return W′, UnitaryTangent(W′, A′)
end
Expand Down

0 comments on commit 8b0aed3

Please sign in to comment.