From 794d55670023e4b153003065e5621e9bdb4f04f4 Mon Sep 17 00:00:00 2001 From: victor Date: Fri, 6 Dec 2024 14:23:58 +0100 Subject: [PATCH] use VectorInterface and update tests --- src/linsolve/lsmr.jl | 17 +++++++++-------- test/linsolve.jl | 18 ++++++++++++------ 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/linsolve/lsmr.jl b/src/linsolve/lsmr.jl index 600c394..c7dca81 100644 --- a/src/linsolve/lsmr.jl +++ b/src/linsolve/lsmr.jl @@ -3,7 +3,7 @@ function linsolve(operator, b, alg::LSMR) return linsolve(operator, b, zerovector(apply_adjoint(operator, b)), alg) end; function linsolve(operator, b, x₀, alg::LSMR) - u = axpby!(1, b, -1, apply_normal(operator, x₀)) + u = add!!(apply_normal(operator, x₀), b, 1, -1) β = norm(u) # initialize GKL factorization @@ -17,7 +17,8 @@ function linsolve(operator, b, x₀, alg::LSMR) alg.conlim > 0 ? ctol = convert(Tr, inv(alg.conlim)) : ctol = zero(Tr) istop = 0 - x = copy(x₀) + # explicit copy via VectorInterface: add!! below may update x in place, which must not overwrite the caller's x₀ + x = scale(x₀, 1) for topit in 1:(alg.maxiter)# the outermost restart loop # Initialize variables for 1st iteration. @@ -49,8 +50,8 @@ function linsolve(operator, b, x₀, alg::LSMR) normr = β normAr = α * β - hbar = zero(T) * x - h = one(T) * fact.V[end] + hbar = scale(x, zero(T)) + h = scale(fact.V[end], one(T)) while length(fact) < alg.krylovdim β = normres(fact) @@ -85,9 +86,9 @@ function linsolve(operator, b, x₀, alg::LSMR) ζbar = -sbar * ζbar # Update h, h_hat, x. 
- hbar = axpby!(1, h, -θbar * ρ / (ρold * ρbarold), hbar) - h = axpby!(1, v, -θnew / ρ, h) - x = axpy!(ζ / (ρ * ρbar), hbar, x) + hbar = add!!(hbar, h, 1, -θbar * ρ / (ρold * ρbarold)) + h = add!!(h, v, 1, -θnew / ρ) + x = add!!(x, hbar, ζ / (ρ * ρbar), 1) ############################################################################## ## @@ -187,7 +188,7 @@ function linsolve(operator, b, x₀, alg::LSMR) end end - u = axpby!(1, b, -1, apply_normal(operator, x)) + u = add!!(apply_normal(operator, x), b, 1, -1) istop != 0 && break diff --git a/test/linsolve.jl b/test/linsolve.jl index b55fb35..89167fa 100644 --- a/test/linsolve.jl +++ b/test/linsolve.jl @@ -53,29 +53,35 @@ end end # Test LSMR complete -@testset "full lsmr" begin - @testset for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "full lsmr ($mode)" for mode in (:vector, :inplace, :outplace) + scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) : + (ComplexF64,) + @testset for T in scalartypes @testset for orth in (cgs2, mgs2, cgsr, mgsr) A = rand(T, (n, n)) v = rand(T, n) w = rand(T, n) alg = LSMR(; orth=orth, krylovdim=2 * n, maxiter=1, atol=10 * n * eps(real(T)), btol=10 * n * eps(real(T))) - S, info = @inferred linsolve(wrapop(A), wrapvec(v), wrapvec(w), alg) + S, info = @inferred linsolve(wrapop(A, Val(mode)), wrapvec(v, Val(mode)), + wrapvec(w, Val(mode)), alg) @test info.converged > 0 @test v ≈ A * unwrapvec(S) + unwrapvec(info.residual) end end end -@testset "iterative lsmr" begin - @testset for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "iterative lsmr ($mode)" for mode in (:vector, :inplace, :outplace) + scalartypes = mode === :vector ? 
(Float32, Float64, ComplexF32, ComplexF64) : + (ComplexF64,) + @testset for T in scalartypes @testset for orth in (cgs2, mgs2, cgsr, mgsr) A = rand(T, (N, N)) v = rand(T, N) w = rand(T, N) alg = LSMR(; orth=orth, krylovdim=N, maxiter=50, atol=10 * N * eps(real(T)), btol=10 * N * eps(real(T))) - S, info = @inferred linsolve(wrapop(A), wrapvec(v), wrapvec(w), alg) + S, info = @inferred linsolve(wrapop(A, Val(mode)), wrapvec(v, Val(mode)), + wrapvec(w, Val(mode)), alg) @test info.converged > 0 @test v ≈ A * unwrapvec(S) + unwrapvec(info.residual) end