Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added StaticArrays extension for expv method #180

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
*.jl.mem
deps/deps.jl
Manifest.toml
.vscode
.vscode
.DS_Store
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ExponentialUtilities"
uuid = "d4d017d3-3776-5f7e-afef-a10c40355c18"
authors = ["Chris Rackauckas <[email protected]>"]
version = "1.26.1"
uuid = "3927d2ad-0af0-4d29-8d72-bf799b39c66b"
jecs marked this conversation as resolved.
Show resolved Hide resolved
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
authors = ["Chris Rackauckas <[email protected]>", "José E. Cruz Serrallés <[email protected]>"]
version = "1.27.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -14,6 +14,12 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
libblastrampoline_jll = "8e850b90-86db-534c-a0d3-1478176c7d93"

[weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
ExponentialUtilitiesStaticArraysExt = "StaticArrays"

[compat]
Adapt = "3.4.0, 4"
Aqua = "0.8"
Expand Down
138 changes: 138 additions & 0 deletions ext/ExponentialUtilitiesStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
module ExponentialUtilitiesStaticArraysExt

export default_tolerance,theta,THETA32,THETA64

using StaticArrays
import Base: @propagate_inbounds
import LinearAlgebra: tr,I,opnorm,norm
import ExponentialUtilities

# Look-Up Table Generation
default_tolerance(::Type{T}) where {T <: AbstractFloat} = eps(T)/2
@inline function trexp(M::Integer,x::T) where {T}
y = T <: BigInt ? one(BigFloat) : T <: Integer ? one(Float64) : one(T)
for m ∈ M:-1:1
y = 1+x/m*y
end
return y
end
h(M::Integer,x::Number) = log(exp(-x)*trexp(M,x))
h̃(M::Integer,x::Number) = ifelse(isodd(M),-1,1)*h(M,-x)
function θf((M,ϵ)::Tuple{<:Integer,<:Number},x::Number)
return h̃(M+1,x)/x-ϵ
end
θf(M::Integer,ϵ::Number) = Base.Fix1(θf,(M,ϵ))
function θfp(M::Integer,x::Number)
Tk = trexp(M+1,-x)
Tkm1 = trexp(M,-x)
return ifelse(isodd(M),-1,1)/x^2*(log(Tk)+x*Tkm1/Tk)
end
θfp(M::Integer) = Base.Fix1(θfp,M)

function newton_find_zero(f::Function,dfdx::Function,x0::Real;xrtol::Real=eps(typeof(x0))/2,maxiter::Integer=100)
0 ≤ xrtol ≤ 1 || throw(DomainError(xrtol,"relative tolerance in x must be in [0,1]"))
maxiter > 0 || throw(DomainError(maxiter,"maxiter should be a positive integer"))
x, xp = x0, typemax(x0)
for _ ∈ 1:maxiter
xp = x
x -= f(x)/dfdx(x)
if abs(x-xp) ≤ xrtol*max(x,xp) || !isfinite(x)
break
end
end
return x
end
function calc_thetas(m_max::Integer,::Type{T};tol::T=default_tolerance(T)) where {T <: AbstractFloat}
m_max > 0 || throw(DomainError(m_max,"argument m_max must be positive"))
ϵ = BigFloat(tol)
θ = Vector{T}(undef,m_max+1)
@inbounds θ[1] = eps(T)
@inbounds for m=1:m_max
θ[m+1] = newton_find_zero(θf(m,ϵ),θfp(m),big(θ[m]),xrtol=ϵ)
end
return θ
end

const P_MAX = 8
const M_MAX = 55
const THETA32 = Tuple(calc_thetas(M_MAX,Float32))
const THETA64 = Tuple(calc_thetas(M_MAX,Float64))

@propagate_inbounds theta(::Type{Float64},m::Integer) = THETA64[m]
@propagate_inbounds theta(::Type{Float32},m::Integer) = THETA32[m]
@propagate_inbounds theta(::Type{Complex{T}},m::Integer) where {T} = theta(T,m)
@propagate_inbounds theta(::Type{T},::Integer) where {T} = throw(DomainError(T,"type must be either Float32 or Float64"))
@propagate_inbounds theta(x::Number,m::Integer) = theta(typeof(x),m)

# runtime parameter search
@propagate_inbounds @inline function calculate_s(α::T,m::I)::I where {T <: Number,I <: Integer}
return ceil(I,α/theta(T,m))
end
@propagate_inbounds @inline function parameter_search(nA::Number,m::I)::I where {I <: Integer}
return m*calculate_s(nA,m)
end
@propagate_inbounds @inline function parameters(A::SMatrix{N,N,T})::Tuple{Int,Int} where {N,T}
1 ≤ N ≤ 50 || throw(DomainError(N,"leading dimension of A must be ≤ 50; larger matrices require Higham's 1-norm estimation algorithm"))
nA = opnorm(A,1)
iszero(nA) && return (0,1)
@inbounds if nA ≤ 4theta(T,M_MAX)*P_MAX*(P_MAX+3)/(M_MAX*1)
mo = argmin(Base.Fix1(parameter_search,nA),1:M_MAX)
s = calculate_s(nA,mo)
return (mo,s)
else
Aᵐ = A*A
pη = √(opnorm(Aᵐ,1))
(Cmo::Int,mo::Int) = (typemax(Int),1)
for p ∈ 2:P_MAX
Aᵐ *= A
η = opnorm(Aᵐ,1)^inv(p+1)
α = max(pη,η)
pη = η
(Cmp::Int,mp::Int) = findmin(Base.Fix1(parameter_search,α),p*(p-1)-1:M_MAX)
(Cmo,mo) = min((Cmp,mp),(Cmo,mo))
end
s = max(Cmo÷mo,1)
return (mo,s)
end
end

# exponential matrix-vector product for SArray types
"""
expv(t::Number,A::SMatrix{N,N},v::SVector{N};kwarg...) → exp(t*A)*v

Computes the matrix-vector product exp(t*A)*v without forming exp(t*A) explicitly.
This implementation is based on the algorithm presented in Al-Mohy & Higham (2011).
Presently, the relative tolerance is fixed to eps(T)/2 where T is the type of the
output.
"""
@propagate_inbounds function ExponentialUtilities.expv(t::Number,A::SMatrix{N,N,T},v::SVector{N}; kwarg...) where {N,T}
Ti = promote_type(StaticArrays.arithmetic_closure(T),eltype(v))
N ≤ 4 && return exp(t*A)*v
Ai::SMatrix{N,N,Ti} = A

μ = tr(Ai)/N
Ai -= μ*I
Ai *= t
mo, s = parameters(Ai)
F = v
Ai /= s
η = exp(μ*t/s)
ϵ = default_tolerance(T)
for _ ∈ 1:s
c₁ = norm(v,Inf)
for j ∈ 1:mo
v = (Ai*v)/j
F += v
c₂ = norm(v,Inf)
c₁+c₂ ≤ ϵ*norm(F,Inf) && break
c₁ = c₂
end
F *= η
v = F
all(isfinite,v) || break
end

return F
end

end
8 changes: 4 additions & 4 deletions src/krylov_phiv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ function expv!(w::AbstractVector{Tw}, t::Real, Ks::KrylovSubspace{T, U};
cache = nothing, expmethod = ExpMethodHigham2005Base()) where {Tw, T, U}
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
@assert length(w)==size(V, 1) "Dimension mismatch"
if cache == nothing
if isnothing(cache)
cache = Matrix{U}(undef, m, m)
elseif isa(cache, ExpvCache)
cache = get_cache(cache, m)
Expand Down Expand Up @@ -105,7 +105,7 @@ function expv!(w::AbstractVector{Complex{Tw}}, t::Complex{Tt}, Ks::KrylovSubspac
cache = nothing, expmethod = ExpMethodHigham2005Base()) where {Tw, Tt, T, U}
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
@assert length(w)==size(V, 1) "Dimension mismatch"
if cache === nothing
if isnothing(cache)
cache = Matrix{U}(undef, m, m)
elseif isa(cache, ExpvCache)
cache = get_cache(cache, m)
Expand Down Expand Up @@ -135,7 +135,7 @@ function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
expmethod = ExpMethodHigham2005Base()) where {Tw, T, U}
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
@assert length(w)==size(V, 1) "Dimension mismatch"
if cache === nothing
if isnothing(cache)
cache = Matrix{U}(undef, m, m)
elseif isa(cache, ExpvCache)
cache = get_cache(cache, m)
Expand Down Expand Up @@ -259,7 +259,7 @@ function phiv!(w::AbstractMatrix, t::Number, Ks::KrylovSubspace{T, U}, k::Intege
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
@assert size(w, 1)==size(V, 1) "Dimension mismatch"
@assert size(w, 2)==k + 1 "Dimension mismatch"
if cache === nothing
if isnothing(cache)
cache = PhivCache(w, m, k)
elseif !isa(cache, PhivCache)
throw(ArgumentError("Cache must be a PhivCache"))
Expand Down
8 changes: 4 additions & 4 deletions src/phi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Software (TOMS), 24(1), 130-156. Theorem 1).
function phi(z::T, k::Integer; cache = nothing,
expmethod = ExpMethodHigham2005Base()) where {T <: Number}
# Construct the matrix
if cache == nothing
if isnothing(cache)
cache = fill(zero(T), k + 1, k + 1)
else
fill!(cache, zero(T))
Expand Down Expand Up @@ -66,7 +66,7 @@ function phiv_dense!(w::AbstractMatrix{T}, A::AbstractMatrix{T},
@assert size(w, 2)==k + 1 "Dimension mismatch"
m = length(v)
# Construct the extended matrix
if cache == nothing
if isnothing(cache)
cache = fill(zero(T), m + k, m + k)
else
@assert size(cache)==(m + k, m + k) "Dimension mismatch"
Expand Down Expand Up @@ -121,7 +121,7 @@ function phi!(out::Vector{Matrix{T}}, A::AbstractMatrix{T}, k::Integer; caches =
expmethod = ExpMethodHigham2005Base()) where {T <: Number}
m = size(A, 1)
@assert length(out) == k + 1&&all(P -> size(P) == (m, m), out) "Dimension mismatch"
if caches == nothing
if isnothing(caches)
e = Vector{T}(undef, m)
W = Matrix{T}(undef, m, k + 1)
C = Matrix{T}(undef, m + k, m + k)
Expand All @@ -143,7 +143,7 @@ function phi!(out::Vector{Matrix{T}}, A::AbstractMatrix{T}, k::Integer; caches =
end
function phi!(out::Vector{Diagonal{T, V}}, A::Diagonal{T, V}, k::Integer;
caches = nothing) where {T <: Number, V <: AbstractVector{T}}
for i in 1:size(A, 1)
for i in axes(A,1)
phiz = phi(A[i, i], k; cache = caches)
for j in 1:(k + 1)
out[j][i, i] = phiz[j]
Expand Down
9 changes: 9 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ end
end
end

@testset "Static Arrays" begin
Random.seed!(0)
for N in (3,4,6,8),t in (0.1,1.0,10.0)
A = I+randn(SMatrix{N,N,Float64})/3
b = randn(SVector{N,Float64})
@test expv(t,A,b) ≈ exp(t*A)*b
end
end

@testset "Arnoldi & Krylov" begin
Random.seed!(0)
n = 20
Expand Down
Loading