diff --git a/src/TensorKitManifolds.jl b/src/TensorKitManifolds.jl index 793de7e..900b784 100644 --- a/src/TensorKitManifolds.jl +++ b/src/TensorKitManifolds.jl @@ -28,6 +28,9 @@ checkbase(x, y, z, args...) = checkbase(checkbase(x, y), z, args...) # the machine epsilon for the elements of an object X, name inspired from eltype scalareps(X) = eps(real(scalartype(X))) +# default SVD algorithm used in the algorithms +default_svd_alg(::AbstractTensorMap) = TensorKit.SVD() + function isisometry(W::AbstractTensorMap; tol=10 * scalareps(W)) WdW = W' * W s = zero(float(real(scalartype(W)))) @@ -61,7 +64,7 @@ end struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm end -function projectisometric!(W::AbstractTensorMap; alg=Polar()) +function projectisometric!(W::AbstractTensorMap; alg=default_svd_alg(W)) if alg isa TensorKit.Polar || alg isa TensorKit.SDD foreach(blocks(W)) do (c, b) return _polarsdd!(b) @@ -98,7 +101,7 @@ projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W)) projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W)) function projectisometric(W::AbstractTensorMap; - alg::TensorKit.OrthogonalFactorizationAlgorithm=Polar()) + alg=default_svd_alg(W)) return projectisometric!(copy(W); alg=alg) end function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap, diff --git a/src/grassmann.jl b/src/grassmann.jl index 459e90c..d5810a7 100644 --- a/src/grassmann.jl +++ b/src/grassmann.jl @@ -10,9 +10,6 @@ using ..TensorKitManifolds: projecthermitian!, projectantihermitian!, projectisometric!, projectcomplement!, PolarNewton import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport! -# Default algorithm used in all of the SVD-based methods -# Picking SVD instead of SDD here, as it seems SDD has stability issues for tensors that are already close to isometry. -const DEFAULT_SVD_ALG = SVD() # special type to store tangent vectors using Z # add SVD of Z = U*S*V upon first creation @@ -60,7 +57,7 @@ function Base.getproperty(Δ::GrassmannTangent, sym::Symbol) elseif sym ∈ (:U, :S, :V) v = Base.getfield(Δ, sym) v !== nothing && return v - U, S, V, = tsvd(Δ.Z; alg=DEFAULT_SVD_ALG) + U, S, V, = tsvd(Δ.Z; alg=default_svd_alg(Δ.Z)) Base.setfield!(Δ, :U, U) Base.setfield!(Δ, :S, S) Base.setfield!(Δ, :V, V) @@ -173,12 +170,12 @@ while the local tangent vector along the retraction curve is `Z′ = - W * V' * sin(α*S) * S * V + U * cos(α * S) * S * V'`. """ -function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=DEFAULT_SVD_ALG) +function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=nothing) W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point")) U, S, V = Δ.U, Δ.S, Δ.V WVd = W * V' sSV, cSV = _sincosSV(α, S, V) # sin(S)*V, cos(S)*V - W′ = projectisometric!(WVd * cSV + U * sSV; alg=alg) + W′ = projectisometric!(WVd * cSV + U * sSV) sSSV = _lmul!(S, sSV) # sin(S)*S*V cSSV = _lmul!(S, cSV) # cos(S)*S*V Z′ = projectcomplement!(-WVd * sSSV + U * cSSV, W′) @@ -194,16 +191,16 @@ This is done by solving the equation `Wold * V' * cos(S) * V + U * sin(S) * V = for the isometries `U`, `V`, and `Y`, and the diagonal matrix `S`, and returning `Z = U * S * V` and `Y`. """ -function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=DEFAULT_SVD_ALG) +function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothing) space(Wold) == space(Wnew) || throw(SpaceMismatch()) WodWn = Wold' * Wnew # V' * cos(S) * V * Y Wneworth = Wnew - Wold * WodWn - Vd, cS, VY = tsvd!(WodWn) + Vd, cS, VY = tsvd!(WodWn; alg=default_svd_alg(WodWn)) Scmplx = acos(cS) # acos always returns a complex TensorMap. We cast back to real if possible. S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx UsS = Wneworth * VY' # U * sin(S) # should be in polar decomposition form - U = projectisometric!(UsS; alg=alg) + U = projectisometric!(UsS) Y = Vd * VY V = Vd' Z = Grassmann.GrassmannTangent(Wold, U * S * V) @@ -217,9 +214,9 @@ Return the unitary Y such that V*Y and W are "in the same Grassmann gauge" (tech from fibre bundles: in the same section), such that they can be related by a Grassmann retraction. """ -function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap; alg=DEFAULT_SVD_ALG) +function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap; alg=nothing) space(W) == space(V) || throw(SpaceMismatch()) - return projectisometric!(V' * W; alg=alg) + return projectisometric!(V' * W) end function transport!(Θ::GrassmannTangent, W::AbstractTensorMap, Δ::GrassmannTangent, α, W′;