Skip to content

Commit

Permalink
Use KrylovKit.linsolve for truncation linear problem, make loss function differentiable
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrehmer committed Mar 5, 2024
1 parent c390bec commit 51507ce
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 31 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502"
Expand Down
34 changes: 19 additions & 15 deletions examples/test_svd_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra
using TensorKit
using ChainRulesCore, Zygote
using ChainRulesCore, ChainRulesTestUtils, Zygote
using PEPSKit

# Non-proper truncated SVD with outdated adjoint
Expand Down Expand Up @@ -45,9 +45,6 @@ function oldsvd_rev(
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4),
)
S = diagm(S)
V = copy(V')

tol = atol > 0 ? atol : rtol * S[1, 1]
F = PEPSKit.invert_S²(S, tol; εbroad) # Includes Lorentzian broadening
S⁻¹ = pinv(S; atol=tol)
Expand All @@ -74,26 +71,33 @@ function oldsvd_rev(
VdV = V' * V
Uproj = one(UUd) - UUd
Vproj = one(VdV) - VdV
ΔA += Uproj * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * Vproj # Old wrong stuff
ΔA += Uproj * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * Vproj # Wrong truncation contribution

return ΔA
end

# Loss function taking the nfirst first singular vectors into account
function nfirst_loss(A, svdfunc; nfirst=1)
# Gauge-invariant loss function
function lossfun(A, svdfunc)
U, _, V = svdfunc(A)
U = convert(Array, U)
V = convert(Array, V)
return real(sum([U[i, i] * V[i, i] for i in 1:nfirst]))
# return real(sum((U * V).data)) # TODO: code up sum for AbstractTensorMap with rrule
return real(tr(U * V)) # trace only allows for m=n
end

m, n = 30, 20
m, n = 30, 30
dtype = ComplexF64
χ = 15
χ = 20
r = TensorMap(randn, dtype, ℂ^m ← ℂ^n)

ltensorkit, gtensorkit = withgradient(A -> nfirst_loss(A, x -> oldsvd(x, χ); nfirst=3), r)
litersvd, gitersvd = withgradient(A -> nfirst_loss(A, x -> itersvd(x, χ); nfirst=3), r)
println("Non-truncated SVD")
ltensorkit, gtensorkit = withgradient(A -> lossfun(A, x -> oldsvd(x, min(m, n))), r)
litersvd, gitersvd = withgradient(A -> lossfun(A, x -> itersvd(x, min(m, n))), r)
@show ltensorkit litersvd
@show norm(gtensorkit[1] - gitersvd[1])

println("\nTruncated SVD to χ=$χ:")
ltensorkit, gtensorkit = withgradient(A -> lossfun(A, x -> oldsvd(x, χ)), r)
litersvd, gitersvd = withgradient(A -> lossfun(A, x -> itersvd(x, χ)), r)
@show ltensorkit litersvd
@show gtensorkit gitersvd
@show norm(gtensorkit[1] - gitersvd[1])

# TODO: Finite-difference check via test_rrule
28 changes: 12 additions & 16 deletions src/utility/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,28 +116,24 @@ function itersvd_rev(
dimγ = k * m # Vectorized dimension of γ-matrix

# Truncation contribution from dU₂ and dV₂
# TODO: Use KrylovKit instead of IterativeSolvers
Sop = LinearMap(k * m + k * n) do v # Left-preconditioned linear problem
γ = reshape(@view(v[1:dimγ]), (k, m))
γd = reshape(@view(v[(dimγ + 1):end]), (k, n))
Γ1 = γ - S⁻¹ * γd * Vproj * Ad
Γ2 = γd - S⁻¹ * γ * Uproj * A
vcat(reshape(Γ1, :), reshape(Γ2, :))
# Left-preconditioned linear map of the truncation problem. The stacked input `v` is split
# into a (k, m) block (first `dimγ` entries) and a (k, n) block (remaining entries); each
# block is corrected by the other one and the result is returned flattened and re-stacked.
function svdlinprob(v)
    upper = reshape(@view(v[1:dimγ]), (k, m))
    lower = reshape(@view(v[(dimγ + 1):end]), (k, n))
    return vcat(
        reshape(upper - S⁻¹ * lower * Vproj * Ad, :),
        reshape(lower - S⁻¹ * upper * Uproj * A, :),
    )
end
if ΔU isa ZeroTangent && ΔV isa ZeroTangent
γ = gmres(Sop, zeros(eltype(A), k * m + k * n))
# `linsolve` returns `(solution, info)`, so destructure to keep only the solution vector.
# The operator is the `svdlinprob` closure; the former `Sop` LinearMap no longer exists.
γ, = linsolve(svdlinprob, zeros(eltype(A), k * m + k * n))
else
# Explicit left-preconditioning
# Set relative tolerance to machine precision to converge SVD gradient error properly
γ = gmres(
Sop,
vcat(reshape(S⁻¹ * ΔU' * Uproj, :), reshape(S⁻¹ * ΔV * Vproj, :));
reltol=eps(real(eltype(A))),
)
y = vcat(reshape(S⁻¹ * ΔU' * Uproj, :), reshape(S⁻¹ * ΔV * Vproj, :))
γ, = linsolve(svdlinprob, y; rtol=eps(real(eltype(A))))
end
γA = reshape(@view(γ[1:dimγ]), k, m)
γAd = reshape(@view(γ[(dimγ + 1):end]), k, n)
ΔA += Uproj * γA' * V + U * γAd * Vproj
γA1 = reshape(@view(γ[1:dimγ]), k, m)
γA2 = reshape(@view(γ[(dimγ + 1):end]), k, n)
ΔA += Uproj * γA1' * V + U * γA2 * Vproj

return ΔA
end

0 comments on commit 51507ce

Please sign in to comment.