From 5c0b6f53097082b282a65b4e53d15d3e418aac01 Mon Sep 17 00:00:00 2001
From: Jutho
Date: Fri, 17 Nov 2023 10:07:30 +0100
Subject: [PATCH] improve svd rrule and tests

---
 ext/TensorKitChainRulesCoreExt.jl | 228 +++++++++++++++++++++---------
 test/ad.jl                        | 101 +++++++++++--
 2 files changed, 251 insertions(+), 78 deletions(-)

diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl
index a5342483..d926afbf 100644
--- a/ext/TensorKitChainRulesCoreExt.jl
+++ b/ext/TensorKitChainRulesCoreExt.jl
@@ -134,85 +134,183 @@ end
 
 # Factorizations
 # --------------
-
-function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; kwargs...)
-    U, S, V, ϵ = tsvd(t; kwargs...)
-
-    function tsvd!_pullback((ΔU, ΔS, ΔV, Δϵ))
-        ∂t = similar(t)
-        for (c, b) in blocks(∂t)
-            copyto!(b,
-                    svd_rev(block(U, c), block(S, c), block(V, c),
-                            block(ΔU, c), block(ΔS, c), block(ΔV, c)))
+function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
+                              trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(),
+                              p::Real=2,
+                              alg::Union{TensorKit.SVD,TensorKit.SDD}=TensorKit.SDD())
+    U, Σ, V, ϵ = tsvd(t; trunc=TensorKit.NoTruncation(), p=p, alg=alg)
+
+    if !(trunc isa TensorKit.NoTruncation) && !isempty(blocksectors(t))
+        Σddata = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(Σ))
+        dims = TensorKit.SectorDict(c => length(b) for (c, b) in Σddata)
+        Σddata, ϵ = TensorKit._truncate!(Σddata, trunc, p)
+        Udata′, Σddata′, Vdata′, dims′ = TensorKit._implement_svdtruncation!(t,
+                                                                             copy(U.data),
+                                                                             Σddata,
+                                                                             copy(V.data),
+                                                                             dims)
+        W = spacetype(t)(dims′)
+        if W ≅ domain(Σ)
+            W = domain(Σ)
         end
-
-        return NoTangent(), ∂t
+        U′, Σ′, V′ = TensorKit._create_svdtensors(t, Udata′, Σddata′, Vdata′, W)
+    else
+        U′, Σ′, V′ = U, Σ, V
     end
 
-    return (U, S, V, ϵ), tsvd!_pullback
-end
-
-"""
-    svd_rev(U, S, V, ΔU, ΔS, ΔV; tol=eps(real(scalartype(Σ)))^(4 / 5))
-
-Implements the following back propagation formula for the SVD:
-
-```math
-ΔA = UΔSV' + U(J + J')SV' + US(K + K')V' + \\frac{1}{2}US^{-1}(L' - L)V'\\
-J = F ∘ (U'ΔU), \\qquad K = F ∘ (V'ΔV), \\qquad L = I ∘ (V'ΔV)\\
-F_{i ≠ j} = \\frac{1}{s_j^2 - s_i^2}\\
-F_{ii} = 0
-```
-
-# References
+    function tsvd!_pullback((ΔU, ΔΣ, ΔV, Δϵ))
+        Δt = similar(t)
+        for (c, b) in blocks(Δt)
+            Uc, Σc, Vc = block(U, c), block(Σ, c), block(V, c)
+            ΔUc, ΔΣc, ΔVc = block(ΔU, c), block(ΔΣ, c), block(ΔV, c)
+            Σdc = view(Σc, diagind(Σc))
+            ΔΣdc = view(ΔΣc, diagind(ΔΣc))
+            copyto!(b, svd_pullback(Uc, Σdc, Vc, ΔUc, ΔΣdc, ΔVc))
+        end
+        return NoTangent(), Δt
+    end
 
-Wan, Zhou-Quan, and Shi-Xin Zhang. 2019. “Automatic Differentiation for Complex Valued SVD.” https://doi.org/10.48550/ARXIV.1909.02659.
-"""
-function svd_rev(U::AbstractMatrix, S::AbstractMatrix, V::AbstractMatrix, ΔU, ΔS, ΔV;
-                 atol::Real=0,
-                 rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4))
-    # project out gauge invariance dependence?
-    # ΔU * U + ΔV * V' = 0
+    return (U′, Σ′, V′, ϵ), tsvd!_pullback
+end
+
+# svd_pullback: pullback implementation for general (possibly truncated) SVD
+#
+# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
+# cotangent ΔU, ΔS, ΔVd variables of truncated SVD
+#
+# Checks whether the cotangent variables are such that they would couple to gauge-dependent
+# degrees of freedom (phases of singular vectors), and prints a warning if this is the case
+#
+# An implementation that only uses U, S, and Vd from truncated SVD is also possible, but
+# requires solving a Sylvester equation, which does not seem to be supported on GPUs.
+#
+# Other implementation considerations for GPU compatibility:
+# no scalar indexing, lots of broadcasting and views
+#
+safe_inv(a, tol) = abs(a) < tol ? zero(a) : inv(a)
+function svd_pullback(U::AbstractMatrix, S::AbstractVector, Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
+                      atol::Real=0,
+                      rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4))
+
+    # Basic size checks and determination
+    m, n = size(U, 1), size(Vd, 2)
+    size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
+    ΔU isa AbstractZero && ΔVd isa AbstractZero && ΔS isa AbstractZero &&
+        return ZeroTangent()
+    p = -1
+    if !(ΔU isa AbstractZero)
+        m == size(ΔU, 1) || throw(DimensionMismatch())
+        p = size(ΔU, 2)
+    end
+    if !(ΔVd isa AbstractZero)
+        n == size(ΔVd, 2) || throw(DimensionMismatch())
+        if p == -1
+            p = size(ΔVd, 1)
+        else
+            p == size(ΔVd, 1) || throw(DimensionMismatch())
+        end
+    end
+    if !(ΔS isa AbstractZero)
+        if ΔS isa AbstractMatrix
+            ΔSr = real(diag(ΔS))
+        else # ΔS isa AbstractVector
+            ΔSr = real(ΔS)
+        end
+        if p == -1
+            p = length(ΔSr)
+        else
+            p == length(ΔSr) || throw(DimensionMismatch())
+        end
+    end
+    Up = view(U, :, 1:p)
+    Vp = view(Vd, 1:p, :)'
+    Sp = view(S, 1:p)
 
+    # tolerance and rank
     tol = atol > 0 ? atol : rtol * S[1, 1]
-    F = _invert_S²(S, tol)
-    S⁻¹ = pinv(S; atol=tol)
-
-    term = ΔS isa ZeroTangent ? ΔS : Diagonal(diag(ΔS))
+    r = something(findlast(>=(tol), S), 0)
+
+    # compute antihermitian part of projection of ΔU and ΔV onto U and V
+    # also already subtract this projection from ΔU and ΔV
+    if !(ΔU isa AbstractZero)
+        UΔU = Up' * ΔU
+        aUΔU = rmul!(UΔU - UΔU', 1 / 2)
+        if m > p
+            ΔU -= Up * UΔU
+        end
+    else
+        aUΔU = fill!(similar(U, (p, p)), 0)
+    end
+    if !(ΔVd isa AbstractZero)
+        VΔV = Vp' * ΔVd'
+        aVΔV = rmul!(VΔV - VΔV', 1 / 2)
+        if n > p
+            ΔVd -= VΔV' * Vp'
+        end
+    else
+        aVΔV = fill!(similar(Vd, (p, p)), 0)
+    end
 
-    J = F .* (U' * ΔU)
-    term += (J + J') * S
-    VΔV = (V * ΔV')
-    K = F .* VΔV
-    term += S * (K + K')
+    # check whether cotangents arise from gauge-invariant objective function
+    mask = abs.(Sp' .- Sp) .< tol
+    gaugepart = view(aUΔU, mask) + view(aVΔV, mask)
+    norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
+    if p > r
+        rprange = (r + 1):p
+        norm(view(aUΔU, rprange, rprange), Inf) < tol ||
+            @warn "cotangents sensitive to gauge choice"
+        norm(view(aVΔV, rprange, rprange), Inf) < tol ||
+            @warn "cotangents sensitive to gauge choice"
+    end
 
-    if scalartype(U) <: Complex && !(ΔV isa ZeroTangent) && !(ΔU isa ZeroTangent)
-        L = LinearAlgebra.Diagonal(diag(VΔV))
-        term += 0.5 * S⁻¹ * (L' - L)
+    UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
+            (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
+    if !(ΔS isa AbstractZero)
+        UdΔAV[diagind(UdΔAV)] .+= ΔSr
     end
+    ΔA = Up * UdΔAV * Vp'
+
+    if r > p # contribution from truncation
+        Ur = view(U, :, (p + 1):r)
+        Vr = view(Vd, (p + 1):r, :)'
+        Sr = view(S, (p + 1):r)
+
+        if !(ΔU isa AbstractZero)
+            UrΔU = Ur' * ΔU
+            if m > r
+                ΔU -= Ur * UrΔU # subtract this part from ΔU
+            end
+        else
+            UrΔU = fill!(similar(U, (r - p, p)), 0)
+        end
+        if !(ΔVd isa AbstractZero)
+            VrΔV = Vr' * ΔVd'
+            if n > r
+                ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
+            end
+        else
+            VrΔV = fill!(similar(Vd, (r - p, p)), 0)
+        end
 
-    ΔA = U * term * V
+        X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
+                         (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
+        Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
+                         (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
 
-    if size(U, 1) != size(V, 2)
-        UUd = U * U'
-        VdV = V' * V
-        ΔA += (one(UUd) - UUd) * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * (one(VdV) - VdV)
+        # ΔA += Ur * X * Vp' + Up * Y' * Vr'
+        mul!(ΔA, Ur, X * Vp', 1, 1)
+        mul!(ΔA, Up * Y', Vr', 1, 1)
     end
 
-    return ΔA
-end
-
-function _invert_S²(S::AbstractMatrix{T}, tol::Real) where {T<:Real}
-    F = similar(S)
-    @inbounds for i in axes(F, 1), j in axes(F, 2)
-        F[i, j] = if i == j
-            zero(T)
-        else
-            sᵢ, sⱼ = S[i, i], S[j, j]
-            1 / (abs(sⱼ - sᵢ) < tol ? tol : sⱼ^2 - sᵢ^2)
-        end
+    if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
+        # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
+        mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
     end
-    return F
+    if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
+        # ΔA += Up * (safe_inv.(Sp, tol) .* ΔVd)
+        mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
+    end
+    return ΔA
 end
 
 function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
diff --git a/test/ad.jl b/test/ad.jl
index 3f87afca..a38776b2 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -65,8 +65,8 @@ precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8
 
 # rrules for functions that destroy inputs
 # ----------------------------------------
-function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...)
-    return ChainRulesCore.rrule(tsvd!, args...)
+function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...)
+    return ChainRulesCore.rrule(tsvd!, args...; kwargs...)
 end
 function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...)
     return ChainRulesCore.rrule(leftorth!, args...; kwargs...)
@@ -131,8 +131,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
           ℂ[SU2Irrep](0 => 2, 1 // 2 => 2),
           ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'))
 
-@testset "Automatic Differentiation ($(eltype(V)))" verbose = true for V in Vlist
-    @testset "Basic Linear Algebra ($T)" for T in (Float64, ComplexF64)
+@testset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in
+                                                                                                              Vlist
+    @testset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
         A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
         B = TensorMap(randn, T, space(A))
 
@@ -149,7 +150,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         test_rrule(permute, A, ((1, 3, 2), (5, 4)))
     end
 
-    @testset "Linear Algebra part II ($T)" for T in (Float64, ComplexF64)
+    @testset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64)
         for i in 1:3
             E = TensorMap(randn, T, ⊗(V[1:i]...) ← ⊗(V[1:i]...))
             test_rrule(LinearAlgebra.tr, E)
@@ -160,7 +161,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         test_rrule(LinearAlgebra.norm, A, 2)
     end
 
-    @testset "TensorOperations ($T)" for T in (Float64, ComplexF64)
+    @testset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64)
         atol = precision(T)
         rtol = precision(T)
 
@@ -224,7 +225,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         end
     end
 
-    @testset "Factorizations ($T)" for T in (Float64, ComplexF64)
+    @testset "Factorizations with scalartype $T" for T in (Float64, ComplexF64)
         A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
         B = TensorMap(randn, T, space(A)')
         C = TensorMap(randn, T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])
@@ -242,12 +243,86 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
             test_rrule(rightorth, C; fkwargs=(; alg=alg), atol)
         end
 
-        # Complex-valued SVD tests are incompatible with finite differencing,
-        # because U and V are not unique.
-        if T <: Real
-            test_rrule(tsvd, A; atol)
-            test_rrule(tsvd, B; atol)
-            test_rrule(tsvd, C; atol)
+        let (U, S, V, ϵ) = tsvd(A)
+            ΔU = TensorMap(randn, scalartype(U), space(U))
+            ΔS = TensorMap(randn, scalartype(S), space(S))
+            ΔV = TensorMap(randn, scalartype(V), space(V))
+            if T <: Complex # remove gauge dependent components
+                gaugepart = U' * ΔU + V * ΔV'
+                for (c, b) in blocks(gaugepart)
+                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+                end
+            end
+            test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
+
+            allS = mapreduce(x -> diag(x[2]), vcat, blocks(S))
+            truncval = (maximum(allS) + minimum(allS)) / 2
+            U, S, V, ϵ = tsvd(A; trunc=truncbelow(truncval))
+            ΔU = TensorMap(randn, scalartype(U), space(U))
+            ΔS = TensorMap(randn, scalartype(S), space(S))
+            ΔV = TensorMap(randn, scalartype(V), space(V))
+            if T <: Complex # remove gauge dependent components
+                gaugepart = U' * ΔU + V * ΔV'
+                for (c, b) in blocks(gaugepart)
+                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+                end
+            end
+            test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
+                       fkwargs=(; trunc=truncbelow(truncval)))
+        end
+
+        let (U, S, V, ϵ) = tsvd(B)
+            ΔU = TensorMap(randn, scalartype(U), space(U))
+            ΔS = TensorMap(randn, scalartype(S), space(S))
+            ΔV = TensorMap(randn, scalartype(V), space(V))
+            if T <: Complex # remove gauge dependent components
+                gaugepart = U' * ΔU + V * ΔV'
+                for (c, b) in blocks(gaugepart)
+                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+                end
+            end
+            test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
+
+            Vtrunc = spacetype(S)(c => cld(size(b, 1), 2) for (c, b) in blocks(S))
+
+            U, S, V, ϵ = tsvd(B; trunc=truncspace(Vtrunc))
+            ΔU = TensorMap(randn, scalartype(U), space(U))
+            ΔS = TensorMap(randn, scalartype(S), space(S))
+            ΔV = TensorMap(randn, scalartype(V), space(V))
+            if T <: Complex # remove gauge dependent components
+                gaugepart = U' * ΔU + V * ΔV'
+                for (c, b) in blocks(gaugepart)
+                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+                end
+            end
+            test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
+                       fkwargs=(; trunc=truncspace(Vtrunc)))
+        end
+
+        let (U, S, V, ϵ) = tsvd(C)
+            ΔU = TensorMap(randn, scalartype(U), space(U))
+            ΔS = TensorMap(randn, scalartype(S), space(S))
+            ΔV = TensorMap(randn, scalartype(V), space(V))
+            if T <: Complex # remove gauge dependent components
+                gaugepart = U' * ΔU + V * ΔV'
+                for (c, b) in blocks(gaugepart)
+                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+                end
+            end
+            test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
+
+            U, S, V, ϵ = tsvd(C; trunc=truncdim(2))
+            ΔU = TensorMap(randn, scalartype(U), space(U))
+            ΔS = TensorMap(randn, scalartype(S), space(S))
+            ΔV = TensorMap(randn, scalartype(V), space(V))
+            if T <: Complex # remove gauge dependent components
+                gaugepart = U' * ΔU + V * ΔV'
+                for (c, b) in blocks(gaugepart)
+                    mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
+                end
+            end
+            test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
+                       fkwargs=(; trunc=truncdim(2)))
         end
     end
 end