diff --git a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl
index 97c61e7..a44f8a6 100644
--- a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl
+++ b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl
@@ -5,6 +5,7 @@ using ChainRulesCore
 using LinearAlgebra
 using VectorInterface
+include("utilities.jl")
 include("linsolve.jl")
 include("eigsolve.jl")
diff --git a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl
index fd695f3..e0bf746 100644
--- a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl
+++ b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl
@@ -1,249 +1,300 @@
-function ChainRulesCore.rrule(::typeof(eigsolve),
-                              A::AbstractMatrix,
+function ChainRulesCore.rrule(config::RuleConfig,
+                              ::typeof(eigsolve),
+                              f,
                               x₀,
                               howmany,
                               which,
-                              alg)
-    (vals, vecs, info) = eigsolve(A, x₀, howmany, which, alg)
-    project_A = ProjectTo(A)
-    T = eltype(vecs[1]) # will be real for real symmetric problems and complex otherwise
+                              alg_primal;
+                              alg_rrule=Arnoldi(; tol=alg_primal.tol,
+                                                krylovdim=alg_primal.krylovdim,
+                                                maxiter=alg_primal.maxiter,
+                                                eager=alg_primal.eager,
+                                                orth=alg_primal.orth))
+    (vals, vecs, info) = eigsolve(f, x₀, howmany, which, alg_primal)
+    T, fᴴ, construct∂f = _prepare_inputs(config, f, vecs, alg_primal)
     function eigsolve_pullback(ΔX)
-        _Δvals = unthunk(ΔX[1])
-        _Δvecs = unthunk(ΔX[2])
-
         ∂self = NoTangent()
         ∂x₀ = ZeroTangent()
         ∂howmany = NoTangent()
         ∂which = NoTangent()
         ∂alg = NoTangent()
-        if _Δvals isa AbstractZero && _Δvecs isa AbstractZero
-            ∂A = ZeroTangent()
-            return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg
-        end
+
+        _Δvals = unthunk(ΔX[1])
+        _Δvecs = unthunk(ΔX[2])
+
+        n = 0
+        while true
+            if !(_Δvals isa AbstractZero) &&
+               any(!iszero, view(_Δvals, (n + 1):length(_Δvals)))
+                n = n + 1
+                continue
+            end
+            if !(_Δvecs isa AbstractZero) &&
+               any(!Base.Fix2(isa, AbstractZero), view(_Δvecs, (n + 1):length(_Δvecs)))
+                n = n + 1
+                continue
+            end
+            break
+        end
+        @assert n <= length(vals)
+        if n == 0
+            ∂f = ZeroTangent()
+            return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
+        end
         if _Δvals isa AbstractZero
-            Δvals = fill(NoTangent(), length(Δvecs))
+            Δvals = fill(zero(vals[1]), n)
         else
-            Δvals = _Δvals
+            @assert length(_Δvals) >= n
+            Δvals = view(_Δvals, 1:n)
         end
         if _Δvecs isa AbstractZero
-            Δvecs = fill(NoTangent(), length(Δvals))
+            Δvecs = fill(ZeroTangent(), n)
         else
-            Δvecs = _Δvecs
-        end
-
-        @assert length(Δvals) == length(Δvecs)
-        @assert length(Δvals) <= length(vals)
-
-        # Determine algorithm to solve linear problem
-        # TODO: Is there a better choice? Should we make this user configurable?
-        linalg = GMRES(;
-                       tol=alg.tol,
-                       krylovdim=alg.krylovdim,
-                       maxiter=alg.maxiter,
-                       orth=alg.orth)
-
-        ws = similar(vecs, length(Δvecs))
-        for i in 1:length(Δvecs)
-            Δλ = Δvals[i]
-            Δv = Δvecs[i]
-            λ = vals[i]
-            v = vecs[i]
-
-            # First threat special cases
-            if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution
-                ws[i] = Δv # some kind of zero
-                continue
-            end
-            if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution
-                ws[i] = Δλ * v
-                continue
-            end
+            @assert length(_Δvecs) >= n
+            Δvecs = view(_Δvecs, 1:n)
+        end
 
-            # General case :
-            if isa(Δv, AbstractZero)
-                b = RecursiveVec(zero(T) * v, T[Δλ])
-            else
-                @assert isa(Δv, typeof(v))
-                b = RecursiveVec(Δv, T[Δλ])
-            end
+        ws = compute_eigsolve_pullback_data(Δvals, Δvecs, view(vals, 1:n), view(vecs, 1:n),
+                                            info, which, fᴴ, T, alg_primal, alg_rrule)
+        # alg_rrule2 = Arnoldi(; tol=alg_rrule.tol, krylovdim=alg_rrule.krylovdim, maxiter=alg_rrule.maxiter, orth=alg_rrule.orth)
+        # ws2 = compute_eigsolve_pullback_data(Δvals, Δvecs, view(vals, 1:n), view(vecs, 1:n), info, which, fᴴ, T, alg_primal, alg_rrule2)
+        # for i = 1:n
+        #     @show ws[i]
+        #     @show ws2[i]
+        # end
 
-            if i > 1 && eltype(A) <: Real &&
-               vals[i] == conj(vals[i - 1]) && Δvals[i] == conj(Δvals[i - 1]) &&
-               vecs[i] == conj(vecs[i - 1]) && Δvecs[i] == conj(Δvecs[i - 1])
-                ws[i] = conj(ws[i - 1])
-                continue
-            end
+        ∂f = construct∂f(ws)
+        return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
+    end
+    return (vals, vecs, info), eigsolve_pullback
+end
 
-            w, reverse_info = let λ = λ, v = v, Aᴴ = A'
-                linsolve(b, zero(T) * b, linalg) do x
-                    x1, x2 = x
-                    γ = 1
-                    # γ can be chosen freely and does not affect the solution theoretically
-                    # The current choice guarantees that the extended matrix is Hermitian if A is
-                    # TODO: is this the best choice in all cases?
-                    y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, A' * x1))
-                    y2 = T[-dot(v, x1)]
-                    return RecursiveVec(y1, y2)
-                end
-            end
-            if info.converged >= i && reverse_info.converged == 0
-                @warn "The cotangent linear problem did not converge, whereas the primal eigenvalue problem did."
-            end
-            ws[i] = w[1]
+function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T,
+                                        alg_primal, alg_rrule::Union{GMRES,BiCGStab})
+    ws = similar(vecs, length(Δvecs))
+    @inbounds for i in 1:length(Δvecs)
+        Δλ = Δvals[i]
+        Δv = Δvecs[i]
+        λ = vals[i]
+        v = vecs[i]
+
+        # First treat special cases
+        if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution
+            ws[i] = zerovector(v)
+            continue
+        end
+        if isa(Δv, AbstractZero) && isa(alg_primal, Lanczos) # simple contribution
+            ws[i] = scale(v, Δλ)
+            continue
         end
-        if A isa StridedMatrix
-            ∂A = InplaceableThunk(Ā -> _buildĀ!(Ā, ws, vecs),
-                                  @thunk(_buildĀ!(zero(A), ws, vecs)))
+        # General case:
+
+        # for the case where `f` is a real matrix, we can expect the following simplification
+        # TODO: can we implement this within our general approach, or generalise this to also
+        # cover the case where `f` is a function?
+        # if i > 1 && eltype(A) <: Real &&
+        #    vals[i] == conj(vals[i - 1]) && Δvals[i] == conj(Δvals[i - 1]) &&
+        #    vecs[i] == conj(vecs[i - 1]) && Δvecs[i] == conj(Δvecs[i - 1])
+        #     ws[i] = conj(ws[i - 1])
+        #     continue
+        # end
+
+        if isa(Δv, AbstractZero)
+            b = (zerovector(v), convert(T, Δλ))
         else
-            ∂A = @thunk(project_A(_buildĀ!(zero(A), ws, vecs)))
+            vdΔv = inner(v, Δv)
+            gaugeᵢ = abs(imag(vdΔv))
+            gaugeᵢ < alg_primal.tol ||
+                @warn "`eigsolve` cotangent for eigenvector $i is sensitive to gauge choice: (|gaugeᵢ| = $gaugeᵢ)"
+            Δv = add(Δv, v, -vdΔv)
+            b = (Δv, convert(T, Δλ))
+        end
+        w, reverse_info = let λ = λ, v = v
+            linsolve(b, zerovector(b), alg_rrule) do x
+                x1, x2 = x
+                y1 = add!(add!(fᴴ(x1), x1, conj(λ), -1), v, x2)
+                y2 = inner(v, x1)
+                return (y1, y2)
+            end
         end
-        return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg
+        if info.converged >= i && reverse_info.converged == 0
+            @warn "`eigsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did"
+        elseif abs(w[2]) > alg_rrule.tol
+            @warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $(w[2])"
+        end
+        ws[i] = w[1]
     end
-    return (vals, vecs, info), eigsolve_pullback
+    return ws
 end
 
-function _buildĀ!(Ā, ws, vs)
-    for i in 1:length(ws)
-        w = ws[i]
-        v = vs[i]
-        if !(w isa AbstractZero)
-            if eltype(Ā) <: Real && eltype(w) <: Complex
-                mul!(Ā, _realview(w), _realview(v)', -1, 1)
-                mul!(Ā, _imagview(w), _imagview(v)', -1, 1)
+function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T,
+                                        alg_primal::Arnoldi, alg_rrule::Arnoldi)
+    n = length(Δvecs)
+    G = zeros(T, n, n)
+    VdΔV = zeros(T, n, n)
+    for j in 1:n
+        for i in 1:n
+            if i < j
+                G[i, j] = conj(G[j, i])
+            elseif i == j
+                G[i, i] = norm(vecs[i])^2
             else
-                mul!(Ā, w, v', -1, 1)
+                G[i, j] = inner(vecs[i], vecs[j])
+            end
+            if !(Δvecs[j] isa AbstractZero)
+                VdΔV[i, j] = inner(vecs[i], Δvecs[j])
             end
         end
     end
-    return Ā
-end
-function _realview(v::AbstractVector{Complex{T}}) where {T}
-    return view(reinterpret(T, v), 2 * (1:length(v)) .- 1)
-end
-function _imagview(v::AbstractVector{Complex{T}}) where {T}
-    return view(reinterpret(T, v), 2 * (1:length(v)))
-end
-
-function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
-                              ::typeof(eigsolve),
-                              A::AbstractMatrix,
-                              x₀,
-                              howmany,
-                              which,
-                              alg)
-    return ChainRulesCore.rrule(eigsolve, A, x₀, howmany, which, alg)
-end
-function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
-                              ::typeof(eigsolve),
-                              f,
-                              x₀,
-                              howmany,
-                              which,
-                              alg)
-    (vals, vecs, info) = eigsolve(f, x₀, howmany, which, alg)
-    T = typeof(dot(vecs[1], vecs[1]))
-    f_pullbacks = map(x -> rrule_via_ad(config, f, x)[2], vecs)
-
-    function eigsolve_pullback(ΔX)
-        _Δvals = unthunk(ΔX[1])
-        _Δvecs = unthunk(ΔX[2])
+    # components along subspace spanned by current eigenvectors
+    tol = alg_primal.tol
+    mask = abs.(transpose(vals) .- vals) .< tol
+    gaugepart = VdΔV[mask] - Diagonal(real(diag(VdΔV)))[mask]
+    Δgauge = norm(gaugepart, Inf)
+    Δgauge < tol ||
+        @warn "`eigsolve` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
+    VdΔV′ = VdΔV - G * Diagonal(diag(VdΔV) ./ diag(G))
+    aVdΔV = VdΔV′ .* conj.(safe_inv.(transpose(vals) .- vals, tol))
+    for i in 1:n
+        aVdΔV[i, i] += Δvals[i]
+    end
+    Gc = cholesky!(G)
+    iGaVdΔV = Gc \ aVdΔV
+    iGVdΔV = Gc \ VdΔV
 
-        ∂self = NoTangent()
-        ∂x₀ = ZeroTangent()
-        ∂howmany = NoTangent()
-        ∂which = NoTangent()
-        ∂alg = NoTangent()
-        if _Δvals isa AbstractZero && _Δvecs isa AbstractZero
-            ∂A = ZeroTangent()
-            return (∂self, ∂A, ∂x₀, ∂howmany, ∂which,
-                    ∂alg)
+    zs = similar(Δvecs)
+    for i in 1:n
+        z = scale(vecs[1], iGaVdΔV[1, i])
+        for j in 2:n
+            z = VectorInterface.add!!(z, vecs[j], iGaVdΔV[j, i])
         end
+        zs[i] = z
+    end
 
-        if _Δvals isa AbstractZero
-            Δvals = fill(NoTangent(), howmany)
-        else
-            Δvals = _Δvals
-        end
-        if _Δvecs isa AbstractZero
-            Δvecs = fill(NoTangent(), howmany)
-        else
-            Δvecs = _Δvecs
-        end
-
-        @assert length(Δvals) == length(Δvecs)
-
-        # Determine algorithm to solve linear problem
-        # TODO: Is there a better choice? Should we make this user configurable?
-        linalg = GMRES(;
-                       tol=alg.tol,
-                       krylovdim=alg.krylovdim,
-                       maxiter=alg.maxiter,
-                       orth=alg.orth)
-        # linalg = BiCGStab(;
-        #     tol = alg.tol,
-        #     maxiter = alg.maxiter*alg.krylovdim,
-        # )
-
-        ws = similar(Δvecs)
-        for i in 1:length(Δvecs)
-            Δλ = Δvals[i]
-            Δv = Δvecs[i]
-            λ = vals[i]
-            v = vecs[i]
-
-            # First threat special cases
-            if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution
-                ws[i] = Δv # some kind of zero
-                continue
-            end
-            if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution
-                ws[i] = Δλ * v
-                continue
+    # components in orthogonal subspace
+    sylvesterarg = similar(Δvecs)
+    for i in 1:n
+        y = fᴴ(zs[i])
+        if !(Δvecs[i] isa AbstractZero)
+            y = VectorInterface.add!!(y, Δvecs[i], +1)
+            for j in 1:n
+                y = VectorInterface.add!!(y, vecs[j], -iGVdΔV[j, i])
             end
+        end
+        sylvesterarg[i] = y
+    end
 
-            # General case :
-            if isa(Δv, AbstractZero)
-                b = RecursiveVec(zero(T) * v, T[-Δλ])
-            else
-                @assert isa(Δv, typeof(v))
-                b = RecursiveVec(-Δv, T[-Δλ])
+    W₀ = (zerovector(vecs[1]), one.(vals))
+    P = orthogonalcomplementprojector(vecs, n, Gc)
+    rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg
+        eigsolve(W₀, n, reverse_wich(which), alg_rrule) do W
+            w, x = W
+            w′ = fᴴ(P(w))
+            @inbounds for i in 1:length(x) # length(x) = n but let us not use outer variables
+                w′ = VectorInterface.add!!(w′, ΔV[i], -x[i])
             end
+            return (w′, conj.(vals) .* x)
+        end
+    end
+    if info.converged >= n && reverse_info.converged < n
+        @warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
+    end
 
-            # TODO: is there any analogy to this for general vector-like user types
-            # if i > 1 && eltype(A) <: Real &&
-            #     vals[i] == conj(vals[i-1]) && Δvals[i] == conj(Δvals[i-1]) &&
-            #     vecs[i] == conj(vecs[i-1]) && Δvecs[i] == conj(Δvecs[i-1])
-            #
-            #     ws[i] = conj(ws[i-1])
-            #     continue
-            # end
-
-            w, reverse_info = let λ = λ, v = v, fᴴ = x -> f_pullbacks[i](x)[2]
-                linsolve(b, zero(T) * b, linalg) do x
-                    x1, x2 = x
-                    γ = 1
-                    # γ can be chosen freely and does not affect the solution theoretically
-                    # The current choice guarantees that the extended matrix is Hermitian if A is
-                    # TODO: is this the best choice in all cases?
-                    y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, fᴴ(x1)))
-                    y2 = T[-dot(v, x1)]
-                    return RecursiveVec(y1, y2)
-                end
+    # cleanup and construct final result
+    ws = zs
+    tol = alg_rrule.tol
+    for i in 1:n
+        w, x = Ws[i]
+        _, ic = findmax(abs, x)
+        factor = 1 / x[ic]
+        x[ic] = zero(x[ic])
+        error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
+        error < tol ||
+            @warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"
+        ws[ic] = VectorInterface.add!!(zs[ic], P(w), -factor)
+    end
+    return ws
+end
+
+# several simplifications happen in the case of a Hermitian eigenvalue problem
+function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T,
+                                        alg_primal::Lanczos, alg_rrule::Arnoldi)
+    n = length(Δvecs)
+    VdΔV = zeros(T, n, n)
+    for j in 1:n
+        for i in 1:n
+            if !(Δvecs[j] isa AbstractZero)
+                VdΔV[i, j] = inner(vecs[i], Δvecs[j])
             end
-            if info.converged >= i && reverse_info.converged == 0
-                @warn "The cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did."
+        end
+    end
+
+    # components along subspace spanned by current eigenvectors
+    tol = alg_primal.tol
+    aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
+    mask = abs.(transpose(vals) .- vals) .< tol
+    gaugepart = view(aVdΔV, mask)
+    Δgauge = norm(gaugepart, Inf)
+    Δgauge < tol ||
+        @warn "`eigsolve` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
+    aVdΔV .= aVdΔV .* safe_inv.(transpose(vals) .- vals, tol)
+    for i in 1:n
+        aVdΔV[i, i] += real(Δvals[i])
+    end
+
+    zs = similar(Δvecs)
+    for i in 1:n
+        z = scale(vecs[1], aVdΔV[1, i])
+        for j in 2:n
+            z = VectorInterface.add!!(z, vecs[j], aVdΔV[j, i])
+        end
+        zs[i] = z
+    end
+
+    # components in orthogonal subspace
+    sylvesterarg = similar(Δvecs)
+    for i in 1:n
+        y = zerovector(vecs[1])
+        if !(Δvecs[i] isa AbstractZero)
+            y = VectorInterface.add!!(y, Δvecs[i], +1)
+            for j in 1:n
+                y = VectorInterface.add!!(y, vecs[j], -VdΔV[j, i])
             end
-            ws[i] = w[1]
+        end
+        sylvesterarg[i] = y
+    end
-        ∂f = f_pullbacks[1](ws[1])[1]
-        for i in 2:length(ws)
-            ∂f = ChainRulesCore.add!!(∂f, f_pullbacks[i](ws[i])[1])
+    W₀ = (zerovector(vecs[1]), one.(vals))
+    P = orthogonalcomplementprojector(vecs, n)
+    rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg
+        eigsolve(W₀, n, reverse_wich(which), alg_rrule) do W
+            w, x = W
+            w′ = fᴴ(P(w))
+            @inbounds for i in 1:length(x) # length(x) = n but let us not use outer variables
+                w′ = VectorInterface.add!!(w′, ΔV[i], -x[i])
+            end
+            return (w′, vals .* x)
         end
-        return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
     end
-    return (vals, vecs, info), eigsolve_pullback
+    if info.converged >= n && reverse_info.converged < n
+        @warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
+    end
+
+    # cleanup and construct final result
+    ws = zs
+    tol = alg_rrule.tol
+    for i in 1:n
+        w, x = Ws[i]
+        _, ic = findmax(abs, x)
+        factor = 1 / x[ic]
+        x[ic] = zero(x[ic])
+        error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
+        error < tol ||
+            @warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"
+        ws[ic] = VectorInterface.add!!(zs[ic], P(w), -factor)
+    end
+    return ws
 end
diff --git a/ext/KrylovKitChainRulesCoreExt/linsolve.jl b/ext/KrylovKitChainRulesCoreExt/linsolve.jl
index 10284b2..235c166 100644
--- a/ext/KrylovKitChainRulesCoreExt/linsolve.jl
+++ b/ext/KrylovKitChainRulesCoreExt/linsolve.jl
@@ -1,91 +1,107 @@
-function ChainRulesCore.rrule(::typeof(linsolve),
-                              A::AbstractMatrix,
-                              b::AbstractVector,
-                              x₀,
-                              algorithm,
-                              a₀,
-                              a₁)
-    (x, info) = linsolve(A, b, x₀, algorithm, a₀, a₁)
-    project_A = ProjectTo(A)
-
-    function linsolve_pullback(X̄)
-        x̄ = unthunk(X̄[1])
-        ∂self = NoTangent()
-        ∂x₀ = ZeroTangent()
-        ∂algorithm = NoTangent()
-        ∂b, reverse_info = linsolve(A', x̄, (zero(a₀) * zero(a₁)) * x̄, algorithm, conj(a₀),
-                                    conj(a₁))
-        if info.converged > 0 && reverse_info.converged == 0
-            @warn "The cotangent linear problem did not converge, whereas the primal linear problem did."
-        end
-        if A isa StridedMatrix
-            ∂A = InplaceableThunk(Ā -> mul!(Ā, ∂b, x', -conj(a₁), true),
-                                  @thunk(-conj(a₁) * ∂b * x'))
-        else
-            ∂A = @thunk(project_A(-conj(a₁) * ∂b * x'))
-        end
-        ∂a₀ = @thunk(-dot(x, ∂b))
-        if a₀ == zero(a₀) && a₁ == one(a₁)
-            ∂a₁ = @thunk(-dot(b, ∂b))
-        else
-            ∂a₁ = @thunk(-dot((b - a₀ * x) / a₁, ∂b))
-        end
-        return ∂self, ∂A, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
-    end
-    return (x, info), linsolve_pullback
-end
-
-function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
-                              ::typeof(linsolve),
-                              A::AbstractMatrix,
-                              b::AbstractVector,
-                              x₀,
-                              algorithm,
-                              a₀,
-                              a₁)
-    return rrule(linsolve, A, b, x₀, algorithm, a₀, a₁)
-end
-
-function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
+function ChainRulesCore.rrule(config::RuleConfig,
                               ::typeof(linsolve),
                               f,
                               b,
                               x₀,
-                              algorithm,
+                              alg_primal,
                               a₀,
-                              a₁)
-    x, info = linsolve(f, b, x₀, algorithm, a₀, a₁)
-
-    # f defines a linear map => pullback defines action of the adjoint
-    (y, f_pullback) = rrule_via_ad(config, f, x)
-    fᴴ(xᴴ) = f_pullback(xᴴ)[2]
-    # TODO can we avoid computing f_pullback if algorithm isa Union{CG,MINRES}?
+                              a₁; alg_rrule=alg_primal)
+    (x, info) = linsolve(f, b, x₀, alg_primal, a₀, a₁)
+    T, fᴴ, construct∂f = _prepare_inputs(config, f, (x,), alg_primal)
     function linsolve_pullback(X̄)
         x̄ = unthunk(X̄[1])
         ∂self = NoTangent()
         ∂x₀ = ZeroTangent()
         ∂algorithm = NoTangent()
-        T = VectorInterface.promote_scale(VectorInterface.promote_scale(x̄, a₀),
-                                          scalartype(a₁))
-        ∂b, reverse_info = linsolve(fᴴ, x̄, zerovector(x̄, T), algorithm, conj(a₀),
+        ∂b, reverse_info = linsolve(fᴴ, x̄, (zero(a₀) * zero(a₁)) * x̄, alg_rrule, conj(a₀),
                                     conj(a₁))
-        if reverse_info.converged == 0
-            @warn "Linear problem for reverse rule did not converge." reverse_info
-        end
-        ∂f = @thunk(f_pullback(scale(∂b, -conj(a₁)))[1])
+        info.converged > 0 && reverse_info.converged == 0 &&
+            @warn "`linsolve` cotangent problem did not converge, whereas the primal linear problem did"
+
+        ∂f = construct∂f((scale(∂b, -conj(a₁)),))
         ∂a₀ = @thunk(-inner(x, ∂b))
-        # ∂a₁ = @thunk(-dot(f(x), ∂b))
         if a₀ == zero(a₀) && a₁ == one(a₁)
             ∂a₁ = @thunk(-inner(b, ∂b))
         else
-            ∂a₁ = @thunk(-inner(scale!!(add(b, x, -a₀), inv(a₁)), ∂b))
+            ∂a₁ = @thunk(-inner(add(b, x, -a₀ / a₁, +1 / a₁), ∂b))
         end
         return ∂self, ∂f, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
     end
     return (x, info), linsolve_pullback
 end
 
+# function generate_linsolve_pullback(alg_rrule, A::AbstractMatrix, b::AbstractVector, a₀, a₁,
+#                                     x, info, alg_primal)
+#     project_A = ProjectTo(A)
+
+#     function linsolve_pullback(X̄)
+#         x̄ = unthunk(X̄[1])
+#         ∂self = NoTangent()
+#         ∂x₀ = ZeroTangent()
+#         ∂algorithm = NoTangent()
+#         ∂b, reverse_info = linsolve(A', x̄, (zero(a₀) * zero(a₁)) * x̄, alg_rrule, conj(a₀),
+#                                     conj(a₁))
+#         if info.converged > 0 && reverse_info.converged == 0
+#             @warn "The cotangent linear problem did not converge, whereas the primal linear problem did."
+#         end
+#         if A isa StridedMatrix
+#             ∂A = InplaceableThunk(Ā -> mul!(Ā, ∂b, x', -conj(a₁), true),
+#                                   @thunk(-conj(a₁) * ∂b * x'))
+#         else
+#             ∂A = @thunk(project_A(-conj(a₁) * ∂b * x'))
+#         end
+#         ∂a₀ = @thunk(-dot(x, ∂b))
+#         if a₀ == zero(a₀) && a₁ == one(a₁)
+#             ∂a₁ = @thunk(-dot(b, ∂b))
+#         else
+#             ∂a₁ = @thunk(-dot((b - a₀ * x) / a₁, ∂b))
+#         end
+#         return ∂self, ∂A, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
+#     end
+#     return linsolve_pullback
+# end
+
+# function generate_linsolve_pullback(config::RuleConfig{>:HasReverseMode},
+#                                     alg_rrule,
+#                                     f,
+#                                     b,
+#                                     a₀,
+#                                     a₁,
+#                                     x,
+#                                     info,
+#                                     alg_primal)
+
+#     # f defines a linear map => pullback defines action of the adjoint
+#     (y, f_pullback) = rrule_via_ad(config, f, x)
+#     fᴴ(xᴴ) = f_pullback(xᴴ)[2]
+#     # TODO can we avoid computing f_pullback if algorithm isa Union{CG,MINRES}?
+
+#     function linsolve_pullback(X̄)
+#         x̄ = unthunk(X̄[1])
+#         ∂self = NoTangent()
+#         ∂x₀ = ZeroTangent()
+#         ∂algorithm = NoTangent()
+#         T = VectorInterface.promote_scale(VectorInterface.promote_scale(x̄, a₀),
+#                                           scalartype(a₁))
+#         ∂b, reverse_info = linsolve(fᴴ, x̄, zerovector(x̄, T), alg_rrule, conj(a₀),
+#                                     conj(a₁))
+#         if reverse_info.converged == 0
+#             @warn "Linear problem for reverse rule did not converge." reverse_info
+#         end
+#         ∂f = @thunk(f_pullback(scale(∂b, -conj(a₁)))[1])
+#         ∂a₀ = @thunk(-inner(x, ∂b))
+#         # ∂a₁ = @thunk(-dot(f(x), ∂b))
+#         if a₀ == zero(a₀) && a₁ == one(a₁)
+#             ∂a₁ = @thunk(-inner(b, ∂b))
+#         else
+#             ∂a₁ = @thunk(-inner(scale!!(add(b, x, -a₀), inv(a₁)), ∂b))
+#         end
+#         return ∂self, ∂f, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
+#     end
+#     return linsolve_pullback
+# end
+
 # frule - currently untested
 function ChainRulesCore.frule((_, ΔA, Δb, Δx₀, _, Δa₀, Δa₁)::Tuple,
                               ::typeof(linsolve),
diff --git a/ext/KrylovKitChainRulesCoreExt/utilities.jl b/ext/KrylovKitChainRulesCoreExt/utilities.jl
new file mode 100644
index 0000000..7528c31
--- /dev/null
+++ b/ext/KrylovKitChainRulesCoreExt/utilities.jl
@@ -0,0 +1,92 @@
+safe_inv(a, tol) = abs(a) < tol ? zero(a) : inv(a)
+
+function orthogonalcomplementprojector(vecs, n)
+    function projector(w)
+        w′ = scale(w, 1)
+        @inbounds for i in 1:n
+            w′ = add!(w′, vecs[i], -inner(vecs[i], w))
+        end
+        return w′
+    end
+    return projector
+end
+function orthogonalcomplementprojector(vecs, n, G::Cholesky)
+    overlaps = zeros(eltype(G), n)
+    function projector(w)
+        @inbounds for i in 1:n
+            overlaps[i] = inner(vecs[i], w)
+        end
+        overlaps = ldiv!(G, overlaps)
+        w′ = scale(w, 1)
+        @inbounds for i in 1:n
+            w′ = VectorInterface.add!!(w′, vecs[i], -overlaps[i])
+        end
+        return w′
+    end
+    return projector
+end
+
+function reverse_wich(which)
+    by, rev = KrylovKit.eigsort(which)
+    return EigSorter(by ∘ conj, rev)
+end
+
+function _prepare_inputs(config, f, vecs, alg_primal)
+    T = scalartype(vecs[1])
+    config isa RuleConfig{>:HasReverseMode} ||
+        throw(ArgumentError("`eigsolve` reverse-mode AD requires AD engine that supports calling back into AD"))
+    f_pullbacks = map(x -> rrule_via_ad(config, f, x)[2], vecs)
+    if alg_primal isa Lanczos
+        fᴴ = v -> f(v)
+    else
+        fᴴ = v -> f_pullbacks[1](v)[2]
+    end
+    construct∂f = let f_pullbacks = f_pullbacks
+        function (ws)
+            ∂f = f_pullbacks[1](ws[1])[1]
+            for i in 2:length(ws)
+                ∂f = ChainRulesCore.add!!(∂f, f_pullbacks[i](ws[i])[1])
+            end
+            return ∂f
+        end
+    end
+    return T, fᴴ, construct∂f
+end
+
+function _prepare_inputs(config, A::AbstractMatrix, vecs, alg_primal)
+    T = eltype(vecs[1]) # will be real for real symmetric problems and complex otherwise
+    fᴴ = v -> A' * v
+    if A isa StridedMatrix
+        construct∂A = ws -> InplaceableThunk(Ā -> _buildĀ!(Ā, ws, vecs),
+                                             @thunk(_buildĀ!(zero(A), ws, vecs)))
+    else
+        construct∂A = let project_A = ProjectTo(A)
+            ws -> @thunk(project_A(_buildĀ!(zero(A), ws, vecs)))
+        end
+    end
+    return T, fᴴ, construct∂A
+end
+
+function _buildĀ!(Ā, ws, vs)
+    for i in 1:length(ws)
+        w = ws[i]
+        v = vs[i]
+        if !(w isa AbstractZero)
+            if eltype(Ā) <: Real && eltype(w) <: Complex
+                mul!(Ā, _realview(w), _realview(v)', +1, +1)
+                mul!(Ā, _imagview(w), _imagview(v)', +1, +1)
+            else
+                mul!(Ā, w, v', +1, 1)
+            end
+        end
+    end
+    return Ā
+end
+
+function _realview(v::AbstractVector{Complex{T}}) where {T}
+    return view(reinterpret(T, v), 2 * (1:length(v)) .- 1)
+end
+
+function _imagview(v::AbstractVector{Complex{T}}) where {T}
+    return view(reinterpret(T, v), 2 * (1:length(v)))
+end
diff --git a/src/eigsolve/arnoldi.jl b/src/eigsolve/arnoldi.jl
index b5ef871..91d8530 100644
--- a/src/eigsolve/arnoldi.jl
+++ b/src/eigsolve/arnoldi.jl
@@ -127,7 +127,7 @@ function schursolve(A, x₀, howmany::Int, which::Selector, alg::Arnoldi)
     ConvergenceInfo(converged, residuals, normresiduals, numiter, numops)
 end
 
-function eigsolve(A, x₀, howmany::Int, which::Selector, alg::Arnoldi)
+function eigsolve(A, x₀, howmany::Int, which::Selector, alg::Arnoldi; alg_rrule=alg)
     T, U, fact, converged, numiter, numops = _schursolve(A, x₀, howmany, which, alg)
     if eltype(T) <: Real && howmany < length(fact) && T[howmany + 1, howmany] != 0
         howmany += 1
diff --git a/src/eigsolve/eigsolve.jl b/src/eigsolve/eigsolve.jl
index 786b4f3..dc2a4fb 100644
--- a/src/eigsolve/eigsolve.jl
+++ b/src/eigsolve/eigsolve.jl
@@ -190,9 +190,8 @@ function eigsolve(f, x₀, howmany::Int=1, which::Selector=:LM; kwargs...)
error("Eigenvalue selector which = $which invalid: real eigenvalues expected with Lanczos algorithm") end elseif T <: Real - if which == :LI || - which == :SI || - (which isa EigSorter && which.by(+im) != which.by(-im)) + by, rev = eigsort(which) + if by(+im) != by(-im) error("Eigenvalue selector which = $which invalid because it does not treat `λ` and `conj(λ)` equally: work in complex arithmetic by providing a complex starting vector `x₀`") end @@ -203,14 +202,14 @@ end function eigselector(f, T::Type; issymmetric::Bool=false, - ishermitian::Bool=issymmetric && !(T <: Complex), + ishermitian::Bool=issymmetric && (T <: Real), krylovdim::Int=KrylovDefaults.krylovdim, maxiter::Int=KrylovDefaults.maxiter, tol::Real=KrylovDefaults.tol, orth::Orthogonalizer=KrylovDefaults.orth, eager::Bool=false, verbosity::Int=0) - if (issymmetric && !(T <: Complex)) || ishermitian + if (T <: Real && issymmetric) || ishermitian return Lanczos(; krylovdim=krylovdim, maxiter=maxiter, tol=tol, @@ -235,7 +234,8 @@ function eigselector(A::AbstractMatrix, tol::Real=KrylovDefaults.tol, orth::Orthogonalizer=KrylovDefaults.orth, eager::Bool=false, - verbosity::Int=0) + verbosity::Int=0, + alg_rrule=nothing) if (T <: Real && issymmetric) || ishermitian return Lanczos(; krylovdim=krylovdim, maxiter=maxiter, diff --git a/src/eigsolve/lanczos.jl b/src/eigsolve/lanczos.jl index 097c7e1..5d64a87 100644 --- a/src/eigsolve/lanczos.jl +++ b/src/eigsolve/lanczos.jl @@ -1,4 +1,9 @@ -function eigsolve(A, x₀, howmany::Int, which::Selector, alg::Lanczos) +function eigsolve(A, x₀, howmany::Int, which::Selector, alg::Lanczos; + alg_rrule=Arnoldi(; tol=alg.tol, + krylovdim=alg.krylovdim, + maxiter=alg.maxiter, + eager=alg.eager, + orth=alg.orth)) krylovdim = alg.krylovdim maxiter = alg.maxiter howmany > krylovdim && diff --git a/src/linsolve/bicgstab.jl b/src/linsolve/bicgstab.jl index 2ff7f87..9ab9a77 100644 --- a/src/linsolve/bicgstab.jl +++ b/src/linsolve/bicgstab.jl @@ -1,4 +1,4 @@ -function linsolve(operator, b, x₀, alg::BiCGStab, a₀::Number=0, a₁::Number=1) +function linsolve(operator, b, x₀, alg::BiCGStab, a₀::Number=0, a₁::Number=1; alg_rrule=alg) # Initial function operation and division defines number type y₀ = apply(operator, x₀) T = typeof(inner(b, y₀) / norm(b) * one(a₀) * one(a₁)) diff --git a/src/linsolve/gmres.jl b/src/linsolve/gmres.jl index 129c43d..56f28f3 100644 --- a/src/linsolve/gmres.jl +++ b/src/linsolve/gmres.jl @@ -1,4 +1,4 @@ -function linsolve(operator, b, x₀, alg::GMRES, a₀::Number=0, a₁::Number=1) +function linsolve(operator, b, x₀, alg::GMRES, a₀::Number=0, a₁::Number=1; alg_rrule=alg) # Initial function operation and division defines number type y₀ = apply(operator, x₀) T = typeof(inner(b, y₀) / norm(b) * one(a₀) * one(a₁)) diff --git a/src/linsolve/linsolve.jl b/src/linsolve/linsolve.jl index 6b4359f..8e18135 100644 --- a/src/linsolve/linsolve.jl +++ b/src/linsolve/linsolve.jl @@ -119,22 +119,18 @@ function linselector(f, orth=KrylovDefaults.orth, verbosity::Int=0) if (T <: Real && issymmetric) || ishermitian - isposdef && + if isposdef return CG(; maxiter=krylovdim * maxiter, tol=tol, verbosity=verbosity) - # TODO: implement MINRES for symmetric but not posdef; for now use GRMES - # return MINRES(krylovdim*maxiter, tol=tol) - return GMRES(; krylovdim=krylovdim, - maxiter=maxiter, - tol=tol, - orth=orth, - verbosity=verbosity) - else - return GMRES(; krylovdim=krylovdim, - maxiter=maxiter, - tol=tol, - orth=orth, - verbosity=verbosity) + else + # TODO: implement MINRES for symmetric but not posdef; for 
now use GRMES + # return MINRES(krylovdim*maxiter, tol=tol) + end end + return GMRES(; krylovdim=krylovdim, + maxiter=maxiter, + tol=tol, + orth=orth, + verbosity=verbosity) end function linselector(A::AbstractMatrix, b, @@ -150,20 +146,16 @@ function linselector(A::AbstractMatrix, orth=KrylovDefaults.orth, verbosity::Int=0) if (T <: Real && issymmetric) || ishermitian - isposdef && + if isposdef return CG(; maxiter=krylovdim * maxiter, tol=tol, verbosity=verbosity) - # TODO: implement MINRES for symmetric but not posdef; for now use GRMES - # return MINRES(krylovdim*maxiter, tol=tol, verbosity = verbosity) - return GMRES(; krylovdim=krylovdim, - maxiter=maxiter, - tol=tol, - orth=orth, - verbosity=verbosity) - else - return GMRES(; krylovdim=krylovdim, - maxiter=maxiter, - tol=tol, - orth=orth, - verbosity=verbosity) + else + # TODO: implement MINRES for symmetric but not posdef; for now use GRMES + # return MINRES(krylovdim*maxiter, tol=tol) + end end + return GMRES(; krylovdim=krylovdim, + maxiter=maxiter, + tol=tol, + orth=orth, + verbosity=verbosity) end diff --git a/test/ad.jl b/test/ad.jl index 095a9fd..fa8ba94 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -103,7 +103,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences Random.seed!(123456789) fdm = ChainRulesTestUtils._fdm -precision(T::Type{<:Number}) = eps(real(T))^(2 / 3) +tolerance(T::Type{<:Number}) = eps(real(T))^(2 / 3) n = 10 N = 30 @@ -138,7 +138,7 @@ function build_mat_example(A, x, howmany::Int, which, alg) info′.converged < howmany && @warn "eigsolve did not converge" for i in 1:howmany d = dot(vecs[i], vecs′[i]) - @assert abs(d) > precision(eltype(A)) + @assert abs(d) > tolerance(eltype(A)) vecs′[i] = vecs′[i] / d end catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) @@ -200,7 +200,7 @@ function build_fun_example(A, x, c, d, howmany::Int, which, alg) info′.converged < howmany′ && @warn "eigsolve did not converge" for i in 1:howmany′ normfix = dot(vecs[i], vecs′[i]) - @assert abs(normfix) > precision(eltype(A)) + @assert abs(normfix) > tolerance(eltype(A)) vecs′[i] = vecs′[i] / normfix end catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) @@ -228,7 +228,7 @@ end x /= norm(x) howmany = 3 - alg = Arnoldi(; tol=cond(A) * eps(real(T)), krylovdim=n) + alg = Arnoldi(; tol=2 * cond(A) * eps(real(T)), krylovdim=n) mat_example_ad, mat_example_fd, Avec, xvec, vals, vecs, howmany = build_mat_example(A, x, howmany, @@ -239,8 +239,8 @@ end (JA′, Jx′) = Zygote.jacobian(mat_example_ad, Avec, xvec) # finite difference comparison using some kind of tolerance heuristic - @test JA ≈ JA′ rtol = (T <: Complex ? 4n : n) * cond(A) * precision(T) - @test Jx ≈ zero(Jx) atol = (T <: Complex ? 4n : n) * cond(A) * precision(T) + @test JA ≈ JA′ rtol = (T <: Complex ? 4n : n) * cond(A) * tolerance(T) + @test norm(Jx, Inf) < (T <: Complex ? 4n : n) * cond(A) * tolerance(T) @test Jx′ == zero(Jx) # some analysis @@ -261,7 +261,7 @@ end end # test orthogonality of vecs and ∂vecs for i in 1:howmany - @test all(<(precision(T)), abs.(vecs[i]' * ∂vecs[i])) + @test all(<(tolerance(T)), abs.(vecs[i]' * ∂vecs[i])) end end end
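
The diff above introduces the `alg_rrule` keyword for `eigsolve`/`linsolve` and the callback-into-AD machinery in `_prepare_inputs`. Below is a hypothetical usage sketch, not part of the changeset: the matrix, the tolerances, and the choice of GMRES as the cotangent-problem solver are illustrative assumptions, with Zygote as the AD engine (as in test/ad.jl above).

    using LinearAlgebra, KrylovKit, Zygote

    # Hypothetical example values: differentiate the largest eigenvalue of a real
    # symmetric matrix through `eigsolve`, with `alg_rrule` selecting the solver
    # used for the cotangent (adjoint) problem.
    n = 20
    x₀ = randn(n)

    function largest_eigval(A)
        S = (A + A') / 2    # symmetrise so that Lanczos applies
        vals, vecs, info = eigsolve(S, x₀, 1, :LR,
                                    Lanczos(; tol=1e-12, krylovdim=n);
                                    alg_rrule=GMRES(; tol=1e-12, krylovdim=n))
        return vals[1]
    end

    A = randn(n, n)
    grad = Zygote.gradient(largest_eigval, A)[1]

    # For a simple extremal eigenvalue λ with normalised eigenvector v, the
    # analytic gradient of λ((A + A')/2) with respect to A is v * v'.
    _, vecs, _ = eigsolve((A + A') / 2, x₀, 1, :LR, Lanczos(; tol=1e-12, krylovdim=n))
    v = vecs[1]
    @assert isapprox(grad, v * v'; rtol=1e-6)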