Skip to content

Commit

Permalink
Implement code suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 2, 2024
1 parent fbe4db5 commit 67f0c7f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
7 changes: 5 additions & 2 deletions src/TensorKitManifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ checkbase(x, y, z, args...) = checkbase(checkbase(x, y), z, args...)
# the machine epsilon for the elements of an object X, name inspired from eltype
scalareps(X) = eps(real(scalartype(X)))

# default SVD algorithm used in the algorithms
default_svd_alg(::AbstractTensorMap) = TensorKit.SVD()

function isisometry(W::AbstractTensorMap; tol=10 * scalareps(W))
WdW = W' * W
s = zero(float(real(scalartype(W))))
Expand Down Expand Up @@ -61,7 +64,7 @@ end

struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm
end
function projectisometric!(W::AbstractTensorMap; alg=Polar())
function projectisometric!(W::AbstractTensorMap; alg=default_svd_alg(W))
if alg isa TensorKit.Polar || alg isa TensorKit.SDD
foreach(blocks(W)) do (c, b)
return _polarsdd!(b)
Expand Down Expand Up @@ -98,7 +101,7 @@ projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W))
projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W))

function projectisometric(W::AbstractTensorMap;
alg::TensorKit.OrthogonalFactorizationAlgorithm=Polar())
alg=default_svd_alg(W))
return projectisometric!(copy(W); alg=alg)
end
function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap,
Expand Down
23 changes: 10 additions & 13 deletions src/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ module Grassmann
using TensorKit
using TensorKit: similarstoragetype, SectorDict
using ..TensorKitManifolds: projecthermitian!, projectantihermitian!,
projectisometric!, projectcomplement!, PolarNewton
projectisometric!, projectcomplement!, PolarNewton,
default_svd_alg
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!

# Default algorithm used in all of the SVD-based methods
# Picking SVD instead of SDD here, as it seems SDD has stability issues for tensors that are already close to isometry.
const DEFAULT_SVD_ALG = SVD()

# special type to store tangent vectors using Z
# add SVD of Z = U*S*V upon first creation
mutable struct GrassmannTangent{T<:AbstractTensorMap,
Expand Down Expand Up @@ -60,7 +57,7 @@ function Base.getproperty(Δ::GrassmannTangent, sym::Symbol)
elseif sym (:U, :S, :V)
v = Base.getfield(Δ, sym)
v !== nothing && return v
U, S, V, = tsvd.Z; alg=DEFAULT_SVD_ALG)
U, S, V, = tsvd.Z; alg=default_svd_alg.Z))
Base.setfield!(Δ, :U, U)
Base.setfield!(Δ, :S, S)
Base.setfield!(Δ, :V, V)
Expand Down Expand Up @@ -173,12 +170,12 @@ while the local tangent vector along the retraction curve is
`Z′ = - W * V' * sin(α*S) * S * V + U * cos(α * S) * S * V'`.
"""
function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=DEFAULT_SVD_ALG)
function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg=nothing)
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
U, S, V = Δ.U, Δ.S, Δ.V
WVd = W * V'
sSV, cSV = _sincosSV(α, S, V) # sin(S)*V, cos(S)*V
W′ = projectisometric!(WVd * cSV + U * sSV; alg=alg)
W′ = projectisometric!(WVd * cSV + U * sSV)
sSSV = _lmul!(S, sSV) # sin(S)*S*V
cSSV = _lmul!(S, cSV) # cos(S)*S*V
Z′ = projectcomplement!(-WVd * sSSV + U * cSSV, W′)
Expand All @@ -194,16 +191,16 @@ This is done by solving the equation `Wold * V' * cos(S) * V + U * sin(S) * V =
for the isometries `U`, `V`, and `Y`, and the diagonal matrix `S`, and returning
`Z = U * S * V` and `Y`.
"""
function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=DEFAULT_SVD_ALG)
function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothing)
space(Wold) == space(Wnew) || throw(SpaceMismatch())
WodWn = Wold' * Wnew # V' * cos(S) * V * Y
Wneworth = Wnew - Wold * WodWn
Vd, cS, VY = tsvd!(WodWn)
Vd, cS, VY = tsvd!(WodWn; alg=default_svd_alg(WodWn))
Scmplx = acos(cS)
# acos always returns a complex TensorMap. We cast back to real if possible.
S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx
UsS = Wneworth * VY' # U * sin(S) # should be in polar decomposition form
U = projectisometric!(UsS; alg=alg)
U = projectisometric!(UsS)
Y = Vd * VY
V = Vd'
Z = Grassmann.GrassmannTangent(Wold, U * S * V)
Expand All @@ -217,9 +214,9 @@ Return the unitary Y such that V*Y and W are "in the same Grassmann gauge" (tech
from fibre bundles: in the same section), such that they can be related by a Grassmann
retraction.
"""
function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap; alg=DEFAULT_SVD_ALG)
function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap; alg=nothing)
space(W) == space(V) || throw(SpaceMismatch())
return projectisometric!(V' * W; alg=alg)
return projectisometric!(V' * W)
end

function transport!::GrassmannTangent, W::AbstractTensorMap, Δ::GrassmannTangent, α, W′;
Expand Down

0 comments on commit 67f0c7f

Please sign in to comment.