From 1eccb72b742830cc0851661e5cb5082f7f0dd63e Mon Sep 17 00:00:00 2001 From: Lukas <37111893+lkdvos@users.noreply.github.com> Date: Mon, 25 Mar 2024 09:33:49 +0100 Subject: [PATCH 1/6] Add DOI to README --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 88a4f56..22ba655 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,9 @@ A Julia package collecting a number of Krylov-based algorithms for linear proble value and eigenvalue problems and the application of functions of linear maps or operators to vectors. -| **Documentation** | **Build Status** | **License** | -|:-----------------:|:----------------:|:-----------:| -| [![][docs-stable-img]][docs-stable-url] [![][docs-dev-img]][docs-dev-url] | [![][aqua-img]][aqua-url] [![CI][github-img]][github-url] [![][codecov-img]][codecov-url] | [![license][license-img]][license-url] | +| **Documentation** | **Build Status** | **Digital Object Identifier** | **License** | +|:-----------------:|:----------------:|:---------------:|:-----------:| +| [![][docs-stable-img]][docs-stable-url] [![][docs-dev-img]][docs-dev-url] | [![][aqua-img]][aqua-url] [![CI][github-img]][github-url] [![][codecov-img]][codecov-url] | [![DOI][doi-img]][doi-url] | [![license][license-img]][license-url] | [docs-dev-img]: https://img.shields.io/badge/docs-dev-blue.svg [docs-dev-url]: https://jutho.github.io/KrylovKit.jl/latest @@ -26,6 +26,9 @@ to vectors. 
[license-img]: http://img.shields.io/badge/license-MIT-brightgreen.svg?style=flat [license-url]: LICENSE.md +[doi-img]: https://zenodo.org/badge/DOI/10.5281/zenodo.10622234.svg +[doi-url]: https://doi.org/10.5281/zenodo.10622234 + ## Release notes for the latest version ### v0.7 From 3cceb8d4db59e212be924f40eb251ede28ca6892 Mon Sep 17 00:00:00 2001 From: Lukas <37111893+lkdvos@users.noreply.github.com> Date: Mon, 25 Mar 2024 11:21:50 +0100 Subject: [PATCH 2/6] Create CITATION.cff --- CITATION.cff | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 CITATION.cff diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..56223a2 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,12 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: +- family-names: "Haegeman" + given-names: "Jutho" + orcid: "https://orcid.org/0000-0002-0858-291X" + +title: "KrylovKit" +version: 0.7.0 +doi: 10.5281/zenodo.10622234 +date-released: 2024-03-14 +url: "https://github.com/Jutho/KrylovKit.jl" From 7e2a11661a8369011eb5dca3786db6d643dcfda4 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 25 Mar 2024 11:33:42 +0100 Subject: [PATCH 3/6] Update CI - Separates the nightly build to make the badge reflect the status of the CI. 
- Update Codecov actions --- .github/workflows/ci-nightly.yml | 39 +++++++++++++++++++++++++++++++ .github/workflows/ci.yml | 40 +++++--------------------------- 2 files changed, 45 insertions(+), 34 deletions(-) create mode 100644 .github/workflows/ci-nightly.yml diff --git a/.github/workflows/ci-nightly.yml b/.github/workflows/ci-nightly.yml new file mode 100644 index 0000000..679f455 --- /dev/null +++ b/.github/workflows/ci-nightly.yml @@ -0,0 +1,39 @@ +name: CI-nightly +on: + push: + branches: + - 'master' + - 'main' + - 'release-' + tags: '*' + pull_request: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test-nightly: + name: Julia nightly - ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: true + matrix: + version: + - 'nightly' + os: + - ubuntu-latest + - macOS-latest + - windows-latest + arch: + - x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@latest + - uses: julia-actions/julia-runtest@latest diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 92e0fc7..96c2f40 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,10 +29,6 @@ jobs: - windows-latest arch: - x64 - # - x86 - # exclude: - # - os: macOS-latest - # arch: x86 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 @@ -43,7 +39,9 @@ jobs: - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: file: lcov.info test-multithreaded: @@ -61,10 +59,6 @@ jobs: - macOS-latest arch: - x64 - # - x86 - # exclude: - # - os: macOS-latest - # arch: x86 steps: - uses: 
actions/checkout@v4 - uses: julia-actions/setup-julia@v1 @@ -77,30 +71,8 @@ jobs: env: JULIA_NUM_THREADS: 4 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: file: lcov.info - test-nightly: - needs: test - name: Julia nightly - ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - 'nightly' - os: - - ubuntu-latest - - macOS-latest - - windows-latest - arch: - - x64 - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest From a20d41c53cb154115bab622c609c18d286c2e54c Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 25 Mar 2024 14:12:46 +0100 Subject: [PATCH 4/6] Bump version to v0.7.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 183993b..bda8f9e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KrylovKit" uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" authors = ["Jutho Haegeman"] -version = "0.7.0" +version = "0.7.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From ac5f08a1889e55142b30e69f273fd70a24c89326 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 26 Mar 2024 12:02:17 +0100 Subject: [PATCH 5/6] Move AD rules to package extension --- Project.toml | 16 ++++++++++++---- .../KrylovKitChainRulesCoreExt.jl | 11 +++++++++++ .../KrylovKitChainRulesCoreExt}/linsolve.jl | 0 src/KrylovKit.jl | 9 ++++----- 4 files changed, 27 insertions(+), 9 deletions(-) create mode 100644 ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl rename {src/adrules => ext/KrylovKitChainRulesCoreExt}/linsolve.jl (100%) diff --git a/Project.toml b/Project.toml index bda8f9e..7ece7f5 100644 --- 
a/Project.toml +++ b/Project.toml @@ -4,29 +4,37 @@ authors = ["Jutho Haegeman"] version = "0.7.1" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +KrylovKitChainRulesCoreExt = "ChainRulesCore" + [compat] Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" FiniteDifferences = "0.12" GPUArraysCore = "0.1" -VectorInterface = "0.4" LinearAlgebra = "1" -Random = "1" +PackageExtensionCompat = "1" Printf = "1" +Random = "1" Test = "1" TestExtras = "0.2" +VectorInterface = "0.4" Zygote = "0.6" julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -35,4 +43,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Aqua", "Random", "TestExtras", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"] +test = ["Test", "Aqua", "Random", "TestExtras", "ChainRulesTestUtils", "ChainRulesCore", "FiniteDifferences", "Zygote"] diff --git a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl new file mode 100644 index 0000000..d2dbee2 --- /dev/null +++ b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl @@ -0,0 +1,11 @@ +module KrylovKitChainRulesCoreExt + +using KrylovKit +using ChainRulesCore +using LinearAlgebra +using VectorInterface + +include("linsolve.jl") + + +end # module diff --git 
a/src/adrules/linsolve.jl b/ext/KrylovKitChainRulesCoreExt/linsolve.jl similarity index 100% rename from src/adrules/linsolve.jl rename to ext/KrylovKitChainRulesCoreExt/linsolve.jl diff --git a/src/KrylovKit.jl b/src/KrylovKit.jl index 54a8da3..16311be 100644 --- a/src/KrylovKit.jl +++ b/src/KrylovKit.jl @@ -24,8 +24,8 @@ using VectorInterface using VectorInterface: add!! using LinearAlgebra using Printf -using ChainRulesCore using GPUArraysCore +using PackageExtensionCompat const IndexRange = AbstractRange{Int} export linsolve, eigsolve, geneigsolve, svdsolve, schursolve, exponentiate, expintegrator @@ -60,7 +60,9 @@ enable_threads() = set_num_threads(Base.Threads.nthreads()) disable_threads() = set_num_threads(1) function __init__() - return set_num_threads(Base.Threads.nthreads()) + @require_extensions + set_num_threads(Base.Threads.nthreads()) + return nothing end struct SplitRange @@ -234,9 +236,6 @@ include("linsolve/bicgstab.jl") include("matrixfun/exponentiate.jl") include("matrixfun/expintegrator.jl") -# rules for automatic differentation -include("adrules/linsolve.jl") - # custom vector types include("recursivevec.jl") include("innerproductvec.jl") From bdf7a97150190f86d3573840b0ad5f3beb836573 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 26 Mar 2024 13:23:28 +0100 Subject: [PATCH 6/6] Include eigsolve AD --- .../KrylovKitChainRulesCoreExt.jl | 2 +- ext/KrylovKitChainRulesCoreExt/eigsolve.jl | 249 ++++++++++++++++++ test/ad.jl | 208 +++++++++++++++ 3 files changed, 458 insertions(+), 1 deletion(-) create mode 100644 ext/KrylovKitChainRulesCoreExt/eigsolve.jl diff --git a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl index d2dbee2..97c61e7 100644 --- a/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl +++ b/ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl @@ -6,6 +6,6 @@ using LinearAlgebra using VectorInterface include("linsolve.jl") - 
+include("eigsolve.jl") end # module diff --git a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl new file mode 100644 index 0000000..fd695f3 --- /dev/null +++ b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl @@ -0,0 +1,249 @@ +function ChainRulesCore.rrule(::typeof(eigsolve), + A::AbstractMatrix, + x₀, + howmany, + which, + alg) + (vals, vecs, info) = eigsolve(A, x₀, howmany, which, alg) + project_A = ProjectTo(A) + T = eltype(vecs[1]) # will be real for real symmetric problems and complex otherwise + + function eigsolve_pullback(ΔX) + _Δvals = unthunk(ΔX[1]) + _Δvecs = unthunk(ΔX[2]) + + ∂self = NoTangent() + ∂x₀ = ZeroTangent() + ∂howmany = NoTangent() + ∂which = NoTangent() + ∂alg = NoTangent() + if _Δvals isa AbstractZero && _Δvecs isa AbstractZero + ∂A = ZeroTangent() + return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg + end + + if _Δvals isa AbstractZero + Δvals = fill(NoTangent(), length(_Δvecs)) # Δvecs is only assigned below; size from the raw cotangent + else + Δvals = _Δvals + end + if _Δvecs isa AbstractZero + Δvecs = fill(NoTangent(), length(Δvals)) + else + Δvecs = _Δvecs + end + + @assert length(Δvals) == length(Δvecs) + @assert length(Δvals) <= length(vals) + + # Determine algorithm to solve linear problem + # TODO: Is there a better choice? Should we make this user configurable? 
+ linalg = GMRES(; + tol=alg.tol, + krylovdim=alg.krylovdim, + maxiter=alg.maxiter, + orth=alg.orth) + + ws = similar(vecs, length(Δvecs)) + for i in 1:length(Δvecs) + Δλ = Δvals[i] + Δv = Δvecs[i] + λ = vals[i] + v = vecs[i] + + # First treat special cases + if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution + ws[i] = Δv # some kind of zero + continue + end + if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution + ws[i] = Δλ * v + continue + end + + # General case : + if isa(Δv, AbstractZero) + b = RecursiveVec(zero(T) * v, T[Δλ]) + else + @assert isa(Δv, typeof(v)) + b = RecursiveVec(Δv, T[Δλ]) + end + + if i > 1 && eltype(A) <: Real && + vals[i] == conj(vals[i - 1]) && Δvals[i] == conj(Δvals[i - 1]) && + vecs[i] == conj(vecs[i - 1]) && Δvecs[i] == conj(Δvecs[i - 1]) + ws[i] = conj(ws[i - 1]) + continue + end + + w, reverse_info = let λ = λ, v = v, Aᴴ = A' + linsolve(b, zero(T) * b, linalg) do x + x1, x2 = x + γ = 1 + # γ can be chosen freely and does not affect the solution theoretically + # The current choice guarantees that the extended matrix is Hermitian if A is + # TODO: is this the best choice in all cases? + y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, A' * x1)) + y2 = T[-dot(v, x1)] + return RecursiveVec(y1, y2) + end + end + if info.converged >= i && reverse_info.converged == 0 + @warn "The cotangent linear problem did not converge, whereas the primal eigenvalue problem did." 
+ end + ws[i] = w[1] + end + + if A isa StridedMatrix + ∂A = InplaceableThunk(Ā -> _buildĀ!(Ā, ws, vecs), + @thunk(_buildĀ!(zero(A), ws, vecs))) + else + ∂A = @thunk(project_A(_buildĀ!(zero(A), ws, vecs))) + end + return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg + end + return (vals, vecs, info), eigsolve_pullback +end + +function _buildĀ!(Ā, ws, vs) + for i in 1:length(ws) + w = ws[i] + v = vs[i] + if !(w isa AbstractZero) + if eltype(Ā) <: Real && eltype(w) <: Complex + mul!(Ā, _realview(w), _realview(v)', -1, 1) + mul!(Ā, _imagview(w), _imagview(v)', -1, 1) + else + mul!(Ā, w, v', -1, 1) + end + end + end + return Ā +end +function _realview(v::AbstractVector{Complex{T}}) where {T} + return view(reinterpret(T, v), 2 * (1:length(v)) .- 1) +end +function _imagview(v::AbstractVector{Complex{T}}) where {T} + return view(reinterpret(T, v), 2 * (1:length(v))) +end + +function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, + ::typeof(eigsolve), + A::AbstractMatrix, + x₀, + howmany, + which, + alg) + return ChainRulesCore.rrule(eigsolve, A, x₀, howmany, which, alg) +end + +function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, + ::typeof(eigsolve), + f, + x₀, + howmany, + which, + alg) + (vals, vecs, info) = eigsolve(f, x₀, howmany, which, alg) + T = typeof(dot(vecs[1], vecs[1])) + f_pullbacks = map(x -> rrule_via_ad(config, f, x)[2], vecs) + + function eigsolve_pullback(ΔX) + _Δvals = unthunk(ΔX[1]) + _Δvecs = unthunk(ΔX[2]) + + ∂self = NoTangent() + ∂x₀ = ZeroTangent() + ∂howmany = NoTangent() + ∂which = NoTangent() + ∂alg = NoTangent() + if _Δvals isa AbstractZero && _Δvecs isa AbstractZero + ∂A = ZeroTangent() + return (∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg) + end + + if _Δvals isa AbstractZero + Δvals = fill(NoTangent(), howmany) + else + Δvals = _Δvals + end + if _Δvecs isa AbstractZero + Δvecs = fill(NoTangent(), howmany) + else + Δvecs = _Δvecs + end + + @assert length(Δvals) == length(Δvecs) + + # Determine algorithm to solve linear 
problem + # TODO: Is there a better choice? Should we make this user configurable? + linalg = GMRES(; + tol=alg.tol, + krylovdim=alg.krylovdim, + maxiter=alg.maxiter, + orth=alg.orth) + # linalg = BiCGStab(; + # tol = alg.tol, + # maxiter = alg.maxiter*alg.krylovdim, + # ) + + ws = similar(Δvecs) + for i in 1:length(Δvecs) + Δλ = Δvals[i] + Δv = Δvecs[i] + λ = vals[i] + v = vecs[i] + + # First treat special cases + if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution + ws[i] = Δv # some kind of zero + continue + end + if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution + ws[i] = Δλ * v + continue + end + + # General case : + if isa(Δv, AbstractZero) + b = RecursiveVec(zero(T) * v, T[-Δλ]) + else + @assert isa(Δv, typeof(v)) + b = RecursiveVec(-Δv, T[-Δλ]) + end + + # TODO: is there any analogy to this for general vector-like user types + # if i > 1 && eltype(A) <: Real && + # vals[i] == conj(vals[i-1]) && Δvals[i] == conj(Δvals[i-1]) && + # vecs[i] == conj(vecs[i-1]) && Δvecs[i] == conj(Δvecs[i-1]) + # + # ws[i] = conj(ws[i-1]) + # continue + # end + + w, reverse_info = let λ = λ, v = v, fᴴ = x -> f_pullbacks[i](x)[2] + linsolve(b, zero(T) * b, linalg) do x + x1, x2 = x + γ = 1 + # γ can be chosen freely and does not affect the solution theoretically + # The current choice guarantees that the extended matrix is Hermitian if A is + # TODO: is this the best choice in all cases? + y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, fᴴ(x1))) + y2 = T[-dot(v, x1)] + return RecursiveVec(y1, y2) + end + end + if info.converged >= i && reverse_info.converged == 0 + @warn "The cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did." 
+ end + ws[i] = w[1] + end + + ∂f = f_pullbacks[1](ws[1])[1] + for i in 2:length(ws) + ∂f = ChainRulesCore.add!!(∂f, f_pullbacks[i](ws[i])[1]) + end + return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg + end + return (vals, vecs, info), eigsolve_pullback +end diff --git a/test/ad.jl b/test/ad.jl index 11de1b9..095a9fd 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -95,3 +95,211 @@ end end end end + +module EigsolveAD +using KrylovKit, LinearAlgebra +using Random, Test, TestExtras +using ChainRulesCore, ChainRulesTestUtils, Zygote, FiniteDifferences +Random.seed!(123456789) + +fdm = ChainRulesTestUtils._fdm +precision(T::Type{<:Number}) = eps(real(T))^(2 / 3) +n = 10 +N = 30 + +function build_mat_example(A, x, howmany::Int, which, alg) + Avec, A_fromvec = to_vec(A) + xvec, x_fromvec = to_vec(x) + + vals, vecs, info = eigsolve(A, x, howmany, which, alg) + info.converged < howmany && @warn "eigsolve did not converge" + if eltype(A) <: Real && length(vals) > howmany && + vals[howmany] == conj(vals[howmany + 1]) + howmany += 1 + end + + function mat_example_ad(Av, xv) + A′ = A_fromvec(Av) + x′ = x_fromvec(xv) + vals′, vecs′, info′ = eigsolve(A′, x′, howmany, which, alg) + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function mat_example_fd(Av, xv) + A′ = A_fromvec(Av) + x′ = x_fromvec(xv) + vals′, vecs′, info′ = eigsolve(A′, x′, howmany, which, alg) + info′.converged < howmany && @warn "eigsolve did not converge" + for i in 1:howmany + d = dot(vecs[i], vecs′[i]) + @assert abs(d) > precision(eltype(A)) + vecs′[i] = vecs′[i] / d + end + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + return mat_example_ad, mat_example_fd, Avec, xvec, vals, vecs, howmany +end + +function build_fun_example(A, x, c, d, howmany::Int, which, alg) + Avec, matfromvec = to_vec(A) + xvec, vecfromvec = to_vec(x) + cvec, = to_vec(c) + dvec, = to_vec(d) + + vals, vecs, info = eigsolve(x, howmany, which, alg) do y + return A * y + c * dot(d, y) + end + info.converged < howmany && @warn "eigsolve did not converge" + if eltype(A) <: Real && length(vals) > howmany && + vals[howmany] == conj(vals[howmany + 1]) + howmany += 1 + end + + fun_example_ad = let howmany′ = howmany + function (Av, xv, cv, dv) + A′ = matfromvec(Av) + x′ = vecfromvec(xv) + c′ = vecfromvec(cv) + d′ = vecfromvec(dv) + + vals′, vecs′, info′ = eigsolve(x′, howmany′, which, alg) do y + return A′ * y + c′ * dot(d′, y) + end + info′.converged < howmany′ && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + end + + fun_example_fd = let howmany′ = howmany + function (Av, xv, cv, dv) + A′ = matfromvec(Av) + x′ = vecfromvec(xv) + c′ = vecfromvec(cv) + d′ = vecfromvec(dv) + + vals′, vecs′, info′ = eigsolve(x′, howmany′, which, alg) do y + return A′ * y + c′ * dot(d′, y) + end + info′.converged < howmany′ && @warn "eigsolve did not converge" + for i in 1:howmany′ + normfix = dot(vecs[i], vecs′[i]) + @assert abs(normfix) > precision(eltype(A)) + vecs′[i] = vecs′[i] / normfix + end + catresults = vcat(vals′[1:howmany′], vecs′[1:howmany′]...) 
+ if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + end + + return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany +end + +@timedtestset "Small eigsolve AD test" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + if T <: Complex + whichlist = (:LM, :SR, :LR, :SI, :LI) + else + whichlist = (:LM, :SR, :LR) + end + @testset for which in whichlist + A = 2 * (rand(T, (n, n)) .- one(T) / 2) + x = 2 * (rand(T, n) .- one(T) / 2) + x /= norm(x) + + howmany = 3 + alg = Arnoldi(; tol=cond(A) * eps(real(T)), krylovdim=n) + mat_example_ad, mat_example_fd, Avec, xvec, vals, vecs, howmany = build_mat_example(A, + x, + howmany, + which, + alg) + + (JA, Jx) = FiniteDifferences.jacobian(fdm, mat_example_fd, Avec, xvec) + (JA′, Jx′) = Zygote.jacobian(mat_example_ad, Avec, xvec) + + # finite difference comparison using some kind of tolerance heuristic + @test JA ≈ JA′ rtol = (T <: Complex ? 4n : n) * cond(A) * precision(T) + @test Jx ≈ zero(Jx) atol = (T <: Complex ? 
4n : n) * cond(A) * precision(T) + @test Jx′ == zero(Jx) + + # some analysis + ∂vals = complex.(JA′[1:howmany, :], JA′[howmany * (n + 1) .+ (1:howmany), :]) + ∂vecs = map(1:howmany) do i + return complex.(JA′[(howmany + (i - 1) * n) .+ (1:n), :], + JA′[(howmany * (n + 2) + (i - 1) * n) .+ (1:n), :]) + end + if eltype(A) <: Complex # test holomorphicity / Cauchy-Riemann equations + # for eigenvalues + @test real(∂vals[:, 1:2:(2n^2)]) ≈ +imag(∂vals[:, 2:2:(2n^2)]) + @test imag(∂vals[:, 1:2:(2n^2)]) ≈ -real(∂vals[:, 2:2:(2n^2)]) + # and for eigenvectors + for i in 1:howmany + @test real(∂vecs[i][:, 1:2:(2n^2)]) ≈ +imag(∂vecs[i][:, 2:2:(2n^2)]) + @test imag(∂vecs[i][:, 1:2:(2n^2)]) ≈ -real(∂vecs[i][:, 2:2:(2n^2)]) + end + end + # test orthogonality of vecs and ∂vecs + for i in 1:howmany + @test all(<(precision(T)), abs.(vecs[i]' * ∂vecs[i])) + end + end + end +end +@timedtestset "Large eigsolve AD test" begin + @testset for T in (ComplexF64,) # (Float64, ComplexF64,) + # disable real case until ChainRules.jl/issues/625 is fixed + if T <: Complex + whichlist = (:LM, :SI) + else + whichlist = (:SR,) + end + @testset for which in whichlist + A = rand(T, (N, N)) .- one(T) / 2 + A = I - (9 // 10) * A / maximum(abs, eigvals(A)) + x = 2 * (rand(T, N) .- one(T) / 2) + x /= norm(x) + c = 2 * (rand(T, N) .- one(T) / 2) + d = 2 * (rand(T, N) .- one(T) / 2) + + howmany = 2 + alg = Arnoldi(; tol=N * N * eps(real(T)), krylovdim=2n) + fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A, + x, + c, + d, + howmany, + which, + alg) + + (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, + cvec, dvec) + (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example_ad, Avec, xvec, cvec, dvec) + @test JA ≈ JA′ + @test Jc ≈ Jc′ + @test Jd ≈ Jd′ + end + end +end + +end