Skip to content

Commit

Permalink
simplify and improve lsmr implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Jan 7, 2025
1 parent 4a36d1e commit c945245
Showing 1 changed file with 42 additions and 70 deletions.
112 changes: 42 additions & 70 deletions src/lssolve/lsmr.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,43 @@
function lssolve(operator, b, alg::LSMR, λ_::Real=0)
# Initial function operation and division defines number type
x₀ = apply_adjoint(operator, b)
T = typeof(inner(x₀, x₀) / inner(b, b))
r = scale(b, one(T))
β = norm(r)
x = scale(x₀, zero(T))
# Initialisation: determine number type
u₀ = b
v₀ = apply_adjoint(operator, u₀)
T = typeof(inner(v₀, v₀) / inner(u₀, u₀))
u = scale(u₀, one(T))
v = scale(v₀, one(T))
β = norm(u)
S = typeof(β)

# Algorithm parameters
maxiter = alg.maxiter
tol::S = alg.tol
λ::S = convert(S, λ_)

# Initialisation
numiter = 0
numops = 1 # operator has been applied once to determine x₀
u = scale!!(r, 1 / β)
v = apply_adjoint(operator, u)
numops += 1
u = scale!!(u, 1 / β)
v = scale!!(v, 1 / β)
α = norm(v)
v = scale!!(v, 1 / α)

# Scalar variables for the bidiagonalization
ᾱ = α
ζ̄ = α * β
ρ = one(S)
θ = zero(S)
ρ̄ = one(S)
c̄ = one(S)
s̄ = zero(S)

absζ̄ = abs(ζ̄)

# Vector variables
x = zerovector(v)
h = v
h̄ = zerovector(x)
h̄ = zerovector(v)

# Initialize variables for estimation of ‖r‖.
β̈ = β
β̇ = zero(S)
ρ̇ = one(S)
τ̃ = zero(S)
θ̃ = zero(S)
ζ = zero(S)
d = zero(S)
r = scale(u, β)
Ah = zerovector(u)
Ah̄ = zerovector(u)

normr = β
normr̄ = β
absζ̄ = abs(ζ̄)
# Algorithm parameters
numiter = 0
numops = 1 # One (adjoint) function application for v
maxiter = alg.maxiter
tol::S = alg.tol
λ::S = convert(S, λ_)

# Check for early return
if abs(ζ̄) < tol
Expand All @@ -52,13 +48,16 @@ function lssolve(operator, b, alg::LSMR, λ_::Real=0)
* ‖ Aᴴ(b - A x) - λ^2 x ‖ = $absζ̄
* number of operations = $numops"""
end
return (x, ConvergenceInfo(1, scale(u, normr), abs(ζ̄), numiter, numops))
return (x, ConvergenceInfo(1, r, abs(ζ̄), numiter, numops))

Check warning on line 51 in src/lssolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/lssolve/lsmr.jl#L51

Added line #L51 was not covered by tests
end

while true
numiter += 1
Av = apply_normal(operator, v)
Ah = add!!(Ah, Av, 1, -θ / ρ)

# βₖ₊₁ uₖ₊₁ = A vₖ - αₖ uₖ
u = add!!(apply_normal(operator, v), u, -α, 1)
u = add!!(Av, u, -α, 1)
β = norm(u)
u = scale!!(u, 1 / β)
# αₖ₊₁ vₖ₊₁ = Aᴴ uₖ₊₁ - βₖ₊₁ vₖ
Expand All @@ -82,7 +81,6 @@ function lssolve(operator, b, alg::LSMR, λ_::Real=0)

# Use a plane rotation P̄ₖ to turn Rₖᵀ to R̄ₖ
ρ̄old = ρ̄ # ρ̄ₖ₋₁
ζold = ζ # ζₖ₋₁
θ̄ = s̄ * ρ # θ̄ₖ = s̄ₖ₋₁ * ρₖ
c̄ρ = c̄ * ρ # c̄ₖ₋₁ * ρₖ
ρ̄ = hypot(c̄ρ, θ) # ρ̄ₖ = sqrt((c̄ₖ₋₁ * ρₖ)^2 + θₖ₊₁^2)
Expand All @@ -93,60 +91,34 @@ function lssolve(operator, b, alg::LSMR, λ_::Real=0)

# Update h, h̄, x
h̄ = add!!(h̄, h, 1, -θ̄ * ρ / (ρold * ρ̄old)) # h̄ₖ = hₖ - θ̄ₖ * ρₖ / (ρₖ₋₁ * ρ̄ₖ₋₁) * h̄ₖ₋₁
x = add!!(x, h̄, ζ / (ρ * ρ̄)) # xₖ = xₖ₋₁ + ζₖ / (ρₖ * ρ̄ₖ) * h̄ₖ
h = add!!(h, v, 1, -θ / ρ) # hₖ₊₁ = vₖ₊₁ - θₖ₊₁ / ρₖ * hₖ

# Estimate of ‖r‖
#-----------------
# Apply rotation P̂ₖ
β́ = ĉ * β̈ # β́ₖ = ĉₖ * β̈ₖ
β̌ = -ŝ * β̈ # β̌ₖ = -ŝₖ * β̈ₖ

# Apply rotation Pₖ
β̂ = c * β́ # β̂ₖ = cₖ * β́ₖ
β̈ = -s * β́ # β̈ₖ₊₁ = -sₖ * β́ₖ
Ah̄ = add!!(Ah̄, Ah, 1, -θ̄ * ρ / (ρold * ρ̄old)) # A h̄ₖ = A hₖ - θ̄ₖ * ρₖ / (ρₖ₋₁ * ρ̄ₖ₋₁) * A h̄ₖ₋₁

# Construct and apply rotation P̃ₖ₋₁
ρ̃ = hypot(ρ̇, θ̄) # ρ̃ₖ₋₁ = sqrt(ρ̇ₖ₋₁^2 + θ̄ₖ^2)
c̃ = ρ̇ / ρ̃ # c̃ₖ₋₁ = ρ̇ₖ₋₁ / ρ̃ₖ₋₁
s̃ = θ̄ / ρ̃ # s̃ₖ = θ̄ₖ / ρ̃ₖ₋₁
θ̃old = θ̃ # θ̃ₖ₋₁
θ̃ = s̃ * ρ̄ # θ̃ₖ = s̃ₖ₋₁ * ρ̄ₖ
ρ̇ = c̃ * ρ̄ # ρ̇ₖ = c̃ₖ₋₁ * ρ̄ₖ
β̇ = -s̃ * β̇ + c̃ * β̂ # β̇ₖ = -s̃ₖ * β̇ₖ₋₁ + c̃ₖ₋₁ * β̂ₖ

# Update t̃ by forward substitution
τ̃ = (ζold - θ̃old * τ̃) / ρ̃ # τ̃ₖ₋₁ = (ζₖ₋₁ - θ̃ₖ₋₁ * τ̃ₖ₋₂) / ρ̃ₖ₋₁
τ̇ = (ζ - θ̃ * τ̃) / ρ̇ # τ̇ₖ = (ζₖ - θ̃ₖ * τ̃ₖ₋₁) / ρ̇ₖ
x = add!!(x, h̄, ζ / (ρ * ρ̄)) # xₖ = xₖ₋₁ + ζₖ / (ρₖ * ρ̄ₖ) * h̄ₖ
r = add!!(r, Ah̄, -ζ / (ρ * ρ̄)) # rₖ = rₖ₋₁ - ζₖ / (ρₖ * ρ̄ₖ) * Ah̄ₖ

# Compute ‖r‖ and ‖r̄‖
sqrtd = hypot(d, β̌)
normr = hypot(β̇ - τ̇, β̈)
normr̄ = hypot(sqrtd, normr)
h = add!!(h, v, 1, -θ / ρ) # hₖ₊₁ = vₖ₊₁ - θₖ₊₁ / ρₖ * hₖ
# Ah is updated in the next iteration when A v is computed

absζ̄ = abs(ζ̄)
if absζ̄ <= tol
if alg.verbosity > 0
@info """LSMR lssolve converged at iteration $numiter:
* ‖ b - A x ‖ = $normr
* ‖ [b - A x; λ x] ‖ = $normr̄
* ‖ b - A x ‖ = $(norm(r))
* ‖ x ‖ = $(norm(x))
* ‖ Aᴴ(b - A x) - λ^2 x ‖ = $absζ̄
* number of operations = $numops"""
end
# TODO: r can probably be determined and updated along the way
r = add!!(apply_normal(operator, x), b, 1, -1)
numops += 1
return (x, ConvergenceInfo(1, r, absζ̄, numiter, numops))
elseif numiter >= maxiter
if alg.verbosity > 0
normr = norm(r)
normx = norm(x)
@warn """LSMR lssolve finished without converging after $numiter iterations:
* ‖ b - A x ‖ = $normr
* ‖ [b - A x; λ x] ‖ = $normr̄
* ‖ b - A x ‖ = $(norm(r))
* ‖ x ‖ = $(norm(x))
* ‖ Aᴴ(b - A x) - λ^2 x ‖ = $absζ̄
* number of operations = $numops"""
end
r = add!!(apply_normal(operator, x), b, 1, -1)
numops += 1
return (x, ConvergenceInfo(0, r, absζ̄, numiter, numops))
end
if alg.verbosity > 1
Expand Down

0 comments on commit c945245

Please sign in to comment.