Skip to content

Commit

Permalink
Add svdsolve AD rule (#84)
Browse files Browse the repository at this point in the history
* Add untested svdsolve rrule

* Fix typos

* Add svdsolve rrule to extension folder

* Delete src/adrules/svdsolve.jl

---------

Co-authored-by: Jutho <[email protected]>
  • Loading branch information
pbrehmer and Jutho authored May 13, 2024
1 parent 4e8e1fa commit c4f6a48
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ using VectorInterface
include("utilities.jl")
include("linsolve.jl")
include("eigsolve.jl")
include("svdsolve.jl")

end # module
89 changes: 89 additions & 0 deletions ext/KrylovKitChainRulesCoreExt/svdsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Reverse rule adopted from tsvd! rrule as found in TensorKit.jl
#
# Runs `svdsolve(A, x₀, howmany, which, alg)` and returns its result together with a
# pullback. The pullback packs the returned singular values/vectors (and their
# cotangents) into dense matrices and delegates to `truncsvd_rrule` to obtain the
# cotangent of `A`. No gradient is propagated to `x₀`, `howmany`, `which`, or `alg`.
function ChainRulesCore.rrule(::typeof(svdsolve), A, x₀, howmany::Int, which::Symbol,
                              alg::GKL)
    val, lvec, rvec, info = svdsolve(A, x₀, howmany, which, alg)

    function svdsolve_pullback((Δval, Δlvec, Δrvec, Δinfo))
        # TODO: These type conversions should probably be handled differently
        U = hcat(lvec...)            # columns are the left singular vectors
        S = diagm(val)               # dense diagonal matrix of singular values
        V = copy(hcat(rvec...)')     # rows are the adjoint right singular vectors
        ΔU = Δlvec isa ZeroTangent ? Δlvec : hcat(Δlvec...)
        ΔS = Δval isa ZeroTangent ? Δval : diagm(Δval)
        # ΔV must have the same layout as V (adjoint of the stacked vectors):
        # truncsvd_rrule forms `V * ΔV'` and `S⁻¹ * ΔV * Vproj`.
        ΔV = Δrvec isa ZeroTangent ? Δrvec : copy(hcat(Δrvec...)')

        ∂A = truncsvd_rrule(A, U, S, V, ΔU, ΔS, ΔV)
        # One tangent per primal argument: (svdsolve, A, x₀, howmany, which, alg)
        return NoTangent(), ∂A, ZeroTangent(), NoTangent(), NoTangent(), NoTangent()
    end

    return (val, lvec, rvec, info), svdsolve_pullback
end

# SVD adjoint with correct truncation contribution
# as presented in: https://arxiv.org/abs/2311.11894
#
# Arguments (shapes as used by the products below, for an m×n matrix `A` and k
# retained singular triplets):
#   A          : m×n matrix
#   U, S, V    : m×k, k×k (diagonal entries are read as singular values), k×n
#   ΔU, ΔS, ΔV : cotangents with the same layout as U, S, V, or `ZeroTangent()`
# Keywords:
#   atol, rtol : tolerance used both for the pseudo-inverse of `S` and to regularise
#                the `1 / (sⱼ² - sᵢ²)` factors of (near-)degenerate singular values.
#                The relative tolerance is taken with respect to `S[1, 1]`.
# Returns the cotangent of `A` (or `ZeroTangent()` if all cotangents are zero).
function truncsvd_rrule(A,
                        U,
                        S,
                        V,
                        ΔU,
                        ΔS,
                        ΔV;
                        atol::Real=0,
                        rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4),)
    T = scalartype(S)
    Ad = copy(A')
    tol = atol > 0 ? atol : rtol * S[1, 1]
    S⁻¹ = pinv(S; atol=tol)

    hasΔU = !(ΔU isa ZeroTangent)
    hasΔV = !(ΔV isa ZeroTangent)

    # Compute possibly divergent F terms
    F = similar(S)
    @inbounds for j in axes(F, 2), i in axes(F, 1)
        F[i, j] = if i == j
            zero(T)
        else
            sᵢ, sⱼ = S[i, i], S[j, j]
            # clamp the denominator when two singular values are (nearly) equal
            Δs = abs(sⱼ - sᵢ) < tol ? tol : sⱼ^2 - sᵢ^2
            1 / Δs
        end
    end

    # dS contribution
    term = ΔS isa ZeroTangent ? ΔS : Diagonal(real.(ΔS))

    # dU₁ and dV₁ off-diagonal contribution (skipped for zero cotangents)
    if hasΔU
        J = F .* (U' * ΔU)
        term += (J + J') * S
    end
    if hasΔV
        VΔV = V * ΔV'
        K = F .* VΔV
        term += S * (K + K')

        # dV₁ diagonal contribution (diagonal of dU₁ is gauged away)
        if hasΔU && scalartype(U) <: Complex
            L = Diagonal(VΔV)
            term += S⁻¹ * (L' - L) / 2
        end
    end
    ΔA = U * term * V

    # With ΔU = ΔV = 0 the linear problem below has a zero right-hand side, hence a
    # zero solution: there is no truncation contribution to add.
    hasΔU || hasΔV || return ΔA

    # Projector contribution for non-square A and dU₂ and dV₂
    UUd = U * U'
    VdV = V' * V
    Uproj = one(UUd) - UUd
    Vproj = one(VdV) - VdV

    # Truncation contribution from dU₂ and dV₂
    function svdlinprob(v) # Left-preconditioned linear problem
        Γ1 = v[1] - S⁻¹ * v[2] * Vproj * Ad
        Γ2 = v[2] - S⁻¹ * v[1] * Uproj * A
        return (Γ1, Γ2)
    end
    # Right-hand side blocks are k×m and k×n matrices; a missing cotangent is
    # replaced by a zero block of the matching shape.
    y₁ = hasΔU ? S⁻¹ * ΔU' * Uproj : zeros(eltype(U), size(U, 2), size(U, 1))
    y₂ = hasΔV ? S⁻¹ * ΔV * Vproj : zeros(eltype(V), size(V, 1), size(V, 2))
    γ, = linsolve(svdlinprob, (y₁, y₂); rtol=eps(real(scalartype(A))))
    ΔA += Uproj * γ[1]' * V + U * γ[2] * Vproj

    return ΔA
end

0 comments on commit c4f6a48

Please sign in to comment.