diff --git a/src/TensorKitManifolds.jl b/src/TensorKitManifolds.jl index 793de7e..447336e 100644 --- a/src/TensorKitManifolds.jl +++ b/src/TensorKitManifolds.jl @@ -7,6 +7,7 @@ export Grassmann, Stiefel, Unitary export inner, retract, transport, transport! using TensorKit +using TensorKit: SVD # Every submodule -- Grassmann, Stiefel, and Unitary -- implements their own methods for # these. The signatures should be @@ -61,7 +62,7 @@ end struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm end -function projectisometric!(W::AbstractTensorMap; alg=Polar()) +function projectisometric!(W::AbstractTensorMap; alg=TensorKit.SVD()) if alg isa TensorKit.Polar || alg isa TensorKit.SDD foreach(blocks(W)) do (c, b) return _polarsdd!(b) diff --git a/src/grassmann.jl b/src/grassmann.jl index 0f63c6f..2cdd06a 100644 --- a/src/grassmann.jl +++ b/src/grassmann.jl @@ -157,7 +157,7 @@ function inner(W::AbstractTensorMap, Δ₁::GrassmannTangent, Δ₂::GrassmannTa end """ - retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = nothing) + retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = TensorKit.SVD()) Retract isometry `W == base(Δ)` within the Grassmann manifold using tangent vector `Δ.Z`. If the singular value decomposition of `Z` is given by `U * S * V`, then the resulting @@ -169,12 +169,12 @@ while the local tangent vector along the retraction curve is `Z′ = - W * V' * sin(α*S) * S * V + U * cos(α * S) * S * V'`. """ -function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=nothing) +function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=SVD()) W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point")) U, S, V = Δ.U, Δ.S, Δ.V WVd = W * V' sSV, cSV = _sincosSV(α, S, V) # sin(S)*V, cos(S)*V - W′ = projectisometric!(WVd * cSV + U * sSV) + W′ = projectisometric!(WVd * cSV + U * sSV; alg) sSSV = _lmul!(S, sSV) # sin(S)*S*V cSSV = _lmul!(S, cSV) # cos(S)*S*V Z′ = projectcomplement!(-WVd * sSSV + U * cSSV, W′) @@ -182,7 +182,7 @@ function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=nothing) end """ - Grassmann.invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = nothing) + Grassmann.invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = SVD()) Return the Grassmann tangent `Z` and unitary `Y` such that `retract(Wold, Z, 1) * Y ≈ Wnew`. @@ -190,7 +190,7 @@ This is done by solving the equation `Wold * V' * cos(S) * V + U * sin(S) * V = for the isometries `U`, `V`, and `Y`, and the diagonal matrix `S`, and returning `Z = U * S * V` and `Y`. """ -function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothing) +function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=SVD()) space(Wold) == space(Wnew) || throw(SectorMismatch()) WodWn = Wold' * Wnew # V' * cos(S) * V * Y Wneworth = Wnew - Wold * WodWn @@ -199,7 +199,7 @@ function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothin # acos always returns a complex TensorMap. We cast back to real if possible. S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx UsS = Wneworth * VY' # U * sin(S) # should be in polar decomposition form - U = projectisometric!(UsS; alg=Polar()) + U = projectisometric!(UsS; alg=SVD()) Y = Vd * VY V = Vd' Z = Grassmann.GrassmannTangent(Wold, U * S * V) diff --git a/src/unitary.jl b/src/unitary.jl index ea74042..6ebf7b6 100644 --- a/src/unitary.jl +++ b/src/unitary.jl @@ -82,10 +82,10 @@ end project(X, W; metric=:euclidean) = project!(copy(X), W; metric=:euclidean) # geodesic retraction, coincides with Stiefel retraction (which is not geodesic for p < n) -function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=nothing) +function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=SVD()) W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point")) E = exp(α * Δ.A) - W′ = projectisometric!(W * E; alg=SDD()) + W′ = projectisometric!(W * E; alg) A′ = Δ.A return W′, UnitaryTangent(W′, A′) end