Skip to content

Commit

Permalink
Move branch from fork to main repository
Browse files Browse the repository at this point in the history
AD extension package
  • Loading branch information
lkdvos authored May 13, 2024
2 parents da91706 + bdf7a97 commit 8f655e6
Show file tree
Hide file tree
Showing 6 changed files with 484 additions and 9 deletions.
16 changes: 12 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,37 @@ authors = ["Jutho Haegeman"]
version = "0.7.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
KrylovKitChainRulesCoreExt = "ChainRulesCore"

[compat]
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
GPUArraysCore = "0.1"
VectorInterface = "0.4"
LinearAlgebra = "1"
Random = "1"
PackageExtensionCompat = "1"
Printf = "1"
Random = "1"
Test = "1"
TestExtras = "0.2"
VectorInterface = "0.4"
Zygote = "0.6"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -35,4 +43,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Aqua", "Random", "TestExtras", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
test = ["Test", "Aqua", "Random", "TestExtras", "ChainRulesTestUtils", "ChainRulesCore", "FiniteDifferences", "Zygote"]
11 changes: 11 additions & 0 deletions ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Package extension holding the ChainRulesCore reverse-mode rules for KrylovKit.
# The rule definitions live in one file per solver and are included below.
module KrylovKitChainRulesCoreExt

using KrylovKit, ChainRulesCore
using LinearAlgebra, VectorInterface

# rules for `linsolve`
include("linsolve.jl")
# rules for `eigsolve`
include("eigsolve.jl")

end # module
249 changes: 249 additions & 0 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Reverse-mode rule for `eigsolve(A, x₀, howmany, which, alg)` with an explicit matrix `A`.
#
# Returns the primal result `(vals, vecs, info)` together with a pullback. The pullback
# takes the cotangent `ΔX` (only `ΔX[1]`, the eigenvalue cotangents, and `ΔX[2]`, the
# eigenvector cotangents, are read) and returns tangents for
# `(eigsolve, A, x₀, howmany, which, alg)`. Only `A` receives a non-trivial tangent.
#
# For every eigenpair with a non-zero cotangent an auxiliary vector `w` is obtained, either
# from a closed-form shortcut or by solving an augmented linear system with GMRES. The
# matrix cotangent is then accumulated by `_buildĀ!` as `Ā -= w * v'`.
function ChainRulesCore.rrule(::typeof(eigsolve),
                              A::AbstractMatrix,
                              x₀,
                              howmany,
                              which,
                              alg)
    (vals, vecs, info) = eigsolve(A, x₀, howmany, which, alg)
    project_A = ProjectTo(A)
    T = eltype(vecs[1]) # will be real for real symmetric problems and complex otherwise

    function eigsolve_pullback(ΔX)
        _Δvals = unthunk(ΔX[1])
        _Δvecs = unthunk(ΔX[2])

        ∂self = NoTangent()
        ∂x₀ = ZeroTangent()
        ∂howmany = NoTangent()
        ∂which = NoTangent()
        ∂alg = NoTangent()
        if _Δvals isa AbstractZero && _Δvecs isa AbstractZero
            ∂A = ZeroTangent()
            return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg
        end

        # Exactly one of the two cotangents may still be a structural zero here; expand it
        # into a per-eigenpair vector of zeros sized after the *other* (non-zero) cotangent.
        # (The length must come from `_Δvecs`: `Δvecs` is not assigned yet at this point.)
        if _Δvals isa AbstractZero
            Δvals = fill(NoTangent(), length(_Δvecs))
        else
            Δvals = _Δvals
        end
        if _Δvecs isa AbstractZero
            Δvecs = fill(NoTangent(), length(Δvals))
        else
            Δvecs = _Δvecs
        end

        @assert length(Δvals) == length(Δvecs)
        @assert length(Δvals) <= length(vals)

        # Determine algorithm to solve linear problem
        # TODO: Is there a better choice? Should we make this user configurable?
        linalg = GMRES(;
                       tol=alg.tol,
                       krylovdim=alg.krylovdim,
                       maxiter=alg.maxiter,
                       orth=alg.orth)

        ws = similar(vecs, length(Δvecs))
        for i in 1:length(Δvecs)
            Δλ = Δvals[i]
            Δv = Δvecs[i]
            λ = vals[i]
            v = vecs[i]

            # First treat special cases
            if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution
                # store an actual zero vector: `ws` has the element type of `vecs` and
                # cannot hold an `AbstractZero`
                ws[i] = zero(T) * v
                continue
            end
            if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution
                # For Δv = 0 the augmented system below (second row: -dot(v, w) = Δλ) has
                # the solution w = -Δλ * v. The shortcut must carry the same sign, because
                # `_buildĀ!` accumulates -w * v'.
                ws[i] = -Δλ * v
                continue
            end

            # General case :
            if isa(Δv, AbstractZero)
                b = RecursiveVec(zero(T) * v, T[Δλ])
            else
                @assert isa(Δv, typeof(v))
                b = RecursiveVec(Δv, T[Δλ])
            end

            # Real matrix with a complex-conjugate pair (values, vectors and cotangents):
            # reuse the conjugate of the previous solution instead of solving again.
            if i > 1 && eltype(A) <: Real &&
               vals[i] == conj(vals[i - 1]) && Δvals[i] == conj(Δvals[i - 1]) &&
               vecs[i] == conj(vecs[i - 1]) && Δvecs[i] == conj(Δvecs[i - 1])
                ws[i] = conj(ws[i - 1])
                continue
            end

            w, reverse_info = let λ = λ, v = v, Aᴴ = A'
                linsolve(b, zero(T) * b, linalg) do x
                    x1, x2 = x
                    γ = 1
                    # γ can be chosen freely and does not affect the solution theoretically
                    # The current choice guarantees that the extended matrix is Hermitian if A is
                    # TODO: is this the best choice in all cases?
                    y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, Aᴴ * x1))
                    y2 = T[-dot(v, x1)]
                    return RecursiveVec(y1, y2)
                end
            end
            if info.converged >= i && reverse_info.converged == 0
                @warn "The cotangent linear problem did not converge, whereas the primal eigenvalue problem did."
            end
            ws[i] = w[1]
        end

        # Accumulate into `Ā` and hand the matrix itself back; the thunks below must
        # evaluate to the accumulated matrix, whatever `_buildĀ!` returns.
        accumulate_Ā!(Ā) = (_buildĀ!(Ā, ws, vecs); Ā)
        if A isa StridedMatrix
            ∂A = InplaceableThunk(accumulate_Ā!, @thunk(accumulate_Ā!(zero(A))))
        else
            ∂A = @thunk(project_A(accumulate_Ā!(zero(A))))
        end
        return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg
    end
    return (vals, vecs, info), eigsolve_pullback
end

"""
    _buildĀ!(Ā, ws, vs)

Accumulate `-ws[i] * vs[i]'` into `Ā` for every index `i` of `ws` and return `Ā`.

Entries of `ws` that are an `AbstractZero` are skipped. When `Ā` has a real element type
but `ws[i]` is complex, only the real part of the outer product is accumulated, computed
from the real and imaginary parts of `ws[i]` and `vs[i]`.
"""
function _buildĀ!(Ā, ws, vs)
    for i in 1:length(ws)
        w = ws[i]
        v = vs[i]
        w isa AbstractZero && continue
        if eltype(Ā) <: Real && eltype(w) <: Complex
            # real(w * v') == real(w) * real(v)' + imag(w) * imag(v)'
            mul!(Ā, _realview(w), _realview(v)', -1, 1)
            mul!(Ā, _imagview(w), _imagview(v)', -1, 1)
        else
            mul!(Ā, w, v', -1, 1)
        end
    end
    # Return the accumulator: callers wrap this function in thunks whose value must be the
    # accumulated matrix (a bare `return` would yield `nothing`).
    return Ā
end
"""
    _realview(v::AbstractVector{Complex{T}})

Return a view of the real parts of `v`: the odd positions of `v` reinterpreted as a
vector of `T`. No data is copied, so writes to the view modify `v`.
"""
function _realview(v::AbstractVector{Complex{T}}) where {T}
    flat = reinterpret(T, v)
    return view(flat, 1:2:(2 * length(v) - 1))
end
"""
    _imagview(v::AbstractVector{Complex{T}})

Return a view of the imaginary parts of `v`: the even positions of `v` reinterpreted as a
vector of `T`. No data is copied, so writes to the view modify `v`.
"""
function _imagview(v::AbstractVector{Complex{T}}) where {T}
    flat = reinterpret(T, v)
    return view(flat, 2:2:(2 * length(v)))
end

# Config-aware entry point for matrix input. The rule configuration is not needed when `A`
# is an explicit matrix, so this method simply forwards to the config-free matrix rule
# (and keeps matrix arguments away from the generic-callable rule that takes a config).
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(eigsolve),
                              A::AbstractMatrix,
                              x₀,
                              howmany,
                              which,
                              alg)
    matrix_rule = ChainRulesCore.rrule(eigsolve, A, x₀, howmany, which, alg)
    return matrix_rule
end

# Reverse-mode rule for `eigsolve(f, x₀, howmany, which, alg)` with a general callable `f`.
#
# Returns the primal result `(vals, vecs, info)` together with a pullback. The action of
# the adjoint of `f` is obtained through `rrule_via_ad(config, f, v)` for every computed
# eigenvector `v`. The pullback takes the cotangent `ΔX` (only `ΔX[1]` and `ΔX[2]` are
# read) and returns tangents for `(eigsolve, f, x₀, howmany, which, alg)`. Only `f`
# receives a non-trivial tangent.
#
# Sign convention: the right-hand side of the augmented system is negated here, so the
# solution `w` is fed directly into the pullbacks of `f`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(eigsolve),
                              f,
                              x₀,
                              howmany,
                              which,
                              alg)
    (vals, vecs, info) = eigsolve(f, x₀, howmany, which, alg)
    T = typeof(dot(vecs[1], vecs[1]))
    f_pullbacks = map(x -> rrule_via_ad(config, f, x)[2], vecs)

    function eigsolve_pullback(ΔX)
        _Δvals = unthunk(ΔX[1])
        _Δvecs = unthunk(ΔX[2])

        ∂self = NoTangent()
        ∂x₀ = ZeroTangent()
        ∂howmany = NoTangent()
        ∂which = NoTangent()
        ∂alg = NoTangent()
        if _Δvals isa AbstractZero && _Δvecs isa AbstractZero
            ∂f = ZeroTangent()
            return (∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg)
        end

        # Exactly one of the two cotangents may still be a structural zero here. Expand it
        # to the length of the other one (not `howmany`, which need not match the length
        # of the cotangent that was actually passed in).
        if _Δvals isa AbstractZero
            Δvals = fill(NoTangent(), length(_Δvecs))
        else
            Δvals = _Δvals
        end
        if _Δvecs isa AbstractZero
            Δvecs = fill(NoTangent(), length(Δvals))
        else
            Δvecs = _Δvecs
        end

        @assert length(Δvals) == length(Δvecs)

        # Determine algorithm to solve linear problem
        # TODO: Is there a better choice (e.g. BiCGStab)? Should we make this user configurable?
        linalg = GMRES(;
                       tol=alg.tol,
                       krylovdim=alg.krylovdim,
                       maxiter=alg.maxiter,
                       orth=alg.orth)

        # Allocate after `vecs` rather than `Δvecs`: when only eigenvalue cotangents are
        # given, `Δvecs` is a vector of zero tangents and cannot store the solutions.
        ws = similar(vecs, length(Δvecs))
        for i in 1:length(Δvecs)
            Δλ = Δvals[i]
            Δv = Δvecs[i]
            λ = vals[i]
            v = vecs[i]

            # First treat special cases
            if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution
                ws[i] = zero(T) * v # an actual zero vector of the same kind as `v`
                continue
            end
            if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution
                ws[i] = Δλ * v
                continue
            end

            # General case :
            if isa(Δv, AbstractZero)
                b = RecursiveVec(zero(T) * v, T[-Δλ])
            else
                @assert isa(Δv, typeof(v))
                b = RecursiveVec(-Δv, T[-Δλ])
            end

            # TODO: is there any analogy, for general vector-like user types, to reusing
            # the conjugate of the previous solution for complex-conjugate pairs?

            w, reverse_info = let λ = λ, v = v, fᴴ = x -> f_pullbacks[i](x)[2]
                linsolve(b, zero(T) * b, linalg) do x
                    x1, x2 = x
                    γ = 1
                    # γ can be chosen freely and does not affect the solution theoretically
                    # The current choice guarantees that the extended matrix is Hermitian if A is
                    # TODO: is this the best choice in all cases?
                    y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, fᴴ(x1)))
                    y2 = T[-dot(v, x1)]
                    return RecursiveVec(y1, y2)
                end
            end
            if info.converged >= i && reverse_info.converged == 0
                @warn "The cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did."
            end
            ws[i] = w[1]
        end

        # Sum the tangents with respect to `f` over all eigenpairs.
        ∂f = f_pullbacks[1](ws[1])[1]
        for i in 2:length(ws)
            ∂f = ChainRulesCore.add!!(∂f, f_pullbacks[i](ws[i])[1])
        end
        return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
    end
    return (vals, vecs, info), eigsolve_pullback
end
File renamed without changes.
9 changes: 4 additions & 5 deletions src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ using VectorInterface
using VectorInterface: add!!
using LinearAlgebra
using Printf
using ChainRulesCore
using GPUArraysCore
using PackageExtensionCompat
const IndexRange = AbstractRange{Int}

export linsolve, eigsolve, geneigsolve, svdsolve, schursolve, exponentiate, expintegrator
Expand Down Expand Up @@ -60,7 +60,9 @@ enable_threads() = set_num_threads(Base.Threads.nthreads())
disable_threads() = set_num_threads(1)

# Module initialization hook. Loads the package extensions through PackageExtensionCompat's
# `@require_extensions`, then sets the thread count to the number of available Julia threads.
# There is no early `return` before these statements: it would make the extension loading
# unreachable.
function __init__()
    @require_extensions
    set_num_threads(Base.Threads.nthreads())
    return nothing
end

struct SplitRange
Expand Down Expand Up @@ -234,9 +236,6 @@ include("linsolve/bicgstab.jl")
include("matrixfun/exponentiate.jl")
include("matrixfun/expintegrator.jl")

# rules for automatic differentiation
include("adrules/linsolve.jl")

# custom vector types
include("recursivevec.jl")
include("innerproductvec.jl")
Expand Down
Loading

0 comments on commit 8f655e6

Please sign in to comment.