Skip to content

Commit

Permalink
use VectorInterface and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorVanthilt committed Dec 6, 2024
1 parent 15b163d commit 794d556
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
17 changes: 9 additions & 8 deletions src/linsolve/lsmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function linsolve(operator, b, alg::LSMR)
return linsolve(operator, b, zerovector(apply_adjoint(operator, b)), alg)

Check warning on line 3 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L2-L3

Added lines #L2 - L3 were not covered by tests
end;
function linsolve(operator, b, x₀, alg::LSMR)
u = axpby!(1, b, -1, apply_normal(operator, x₀))
u = add!!(apply_normal(operator, x₀), b, 1, -1)
β = norm(u)

# initialize GKL factorization
Expand All @@ -17,7 +17,8 @@ function linsolve(operator, b, x₀, alg::LSMR)
alg.conlim > 0 ? ctol = convert(Tr, inv(alg.conlim)) : ctol = zero(Tr)
istop = 0

x = copy(x₀)
# TODO: make this an explicit copy that works with the testing datatypes
x = x₀

for topit in 1:(alg.maxiter)# the outermost restart loop
# Initialize variables for 1st iteration.
Expand Down Expand Up @@ -49,8 +50,8 @@ function linsolve(operator, b, x₀, alg::LSMR)
normr = β
normAr = α * β

hbar = zero(T) * x
h = one(T) * fact.V[end]
hbar = scale(x, zero(T))
h = scale(fact.V[end], one(T))

while length(fact) < alg.krylovdim
β = normres(fact)
Expand Down Expand Up @@ -85,9 +86,9 @@ function linsolve(operator, b, x₀, alg::LSMR)
ζbar = -sbar * ζbar

# Update h, h_hat, x.
hbar = axpby!(1, h, -θbar * ρ / (ρold * ρbarold), hbar)
h = axpby!(1, v, -θnew / ρ, h)
x = axpy!(ζ /* ρbar), hbar, x)
hbar = add!!(hbar, h, 1, -θbar * ρ / (ρold * ρbarold))
h = add!!(h, v, 1, -θnew / ρ)
x = add!!(x, hbar, ζ /* ρbar), 1)

##############################################################################
##
Expand Down Expand Up @@ -187,7 +188,7 @@ function linsolve(operator, b, x₀, alg::LSMR)
end
end

u = axpby!(1, b, -1, apply_normal(operator, x))
u = add!!(apply_normal(operator, x), b, 1, -1)

istop != 0 && break

Expand Down
18 changes: 12 additions & 6 deletions test/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,35 @@ end
end

# Test LSMR complete
@testset "full lsmr" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "full lsmr ($mode)" for mode in (:vector, :inplace, :outplace)
scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :
(ComplexF64,)
@testset for T in scalartypes
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (n, n))
v = rand(T, n)
w = rand(T, n)
alg = LSMR(; orth=orth, krylovdim=2 * n, maxiter=1, atol=10 * n * eps(real(T)),
btol=10 * n * eps(real(T)))
S, info = @inferred linsolve(wrapop(A), wrapvec(v), wrapvec(w), alg)
S, info = @inferred linsolve(wrapop(A, Val(mode)), wrapvec(v, Val(mode)),
wrapvec(w, Val(mode)), alg)
@test info.converged > 0
@test v A * unwrapvec(S) + unwrapvec(info.residual)
end
end
end
@testset "iterative lsmr" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "iterative lsmr ($mode)" for mode in (:vector, :inplace, :outplace)
scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :
(ComplexF64,)
@testset for T in scalartypes
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (N, N))
v = rand(T, N)
w = rand(T, N)
alg = LSMR(; orth=orth, krylovdim=N, maxiter=50, atol=10 * N * eps(real(T)),
btol=10 * N * eps(real(T)))
S, info = @inferred linsolve(wrapop(A), wrapvec(v), wrapvec(w), alg)
S, info = @inferred linsolve(wrapop(A, Val(mode)), wrapvec(v, Val(mode)),
wrapvec(w, Val(mode)), alg)
@test info.converged > 0
@test v A * unwrapvec(S) + unwrapvec(info.residual)
end
Expand Down

0 comments on commit 794d556

Please sign in to comment.