Skip to content

Commit

Permalink
Fix dtype in GMRES and MINRES (#345)
Browse files Browse the repository at this point in the history
In the iterative linear solvers take `dtype` from the domain or the
codomain of the linear operator to be "inverted". This avoids reading
the `dtype` property of the linear operator, which may not be properly
defined in some cases. Specifically, the classes `GMRES` and
`MinimumResidual` were fixed.
  • Loading branch information
e-moral-sanchez authored Oct 12, 2023
1 parent 5d2950d commit 652b29f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
6 changes: 3 additions & 3 deletions psydac/linalg/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def __init__(self, A, *, x0=None, tol=1e-6, maxiter=1000, verbose=False, recycle

assert isinstance(A, LinearOperator)
assert A.domain.dimension == A.codomain.dimension
assert A.dtype == float
assert A.domain.dtype == float
domain = A.codomain
codomain = A.domain

Expand Down Expand Up @@ -1731,7 +1731,7 @@ def __init__(self, A, *, x0=None, tol=1e-6, maxiter=100, verbose=False, recycle=
self._tmps = {key: domain.zeros() for key in ("r", "p")}

# Initialize upper Hessenberg matrix
self._H = np.zeros((self._options["maxiter"] + 1, self._options["maxiter"]), dtype=A.dtype)
self._H = np.zeros((self._options["maxiter"] + 1, self._options["maxiter"]), dtype=A.domain.dtype)
self._Q = []
self._info = None

Expand Down Expand Up @@ -1879,7 +1879,7 @@ def solve(self, b, out=None):
def solve_triangular(self, T, d):
# Backwards substitution. Assumes T is upper triangular
k = T.shape[0]
y = np.zeros((k,), dtype=self._A.dtype)
y = np.zeros((k,), dtype=self._A.domain.dtype)

for k1 in range(k):
temp = 0.
Expand Down
17 changes: 16 additions & 1 deletion psydac/linalg/tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def test_solver_tridiagonal(n, p, dtype, solver, verbose=False):
solv = inverse(A, solver, tol=tol, verbose=False, recycle=True)
solvt = solv.transpose()
solvh = solv.H
solv2 = inverse(A@A, solver, tol=1e-13, verbose=True, recycle=True) # Test solver of composition of operators

# Manufacture right-hand-side vector from exact solution
be = A @ xe
Expand Down Expand Up @@ -135,11 +136,21 @@ def test_solver_tridiagonal(n, p, dtype, solver, verbose=False):
assert np.array_equal(xh.toarray(), solvh_x0.toarray())
assert xh is not solvh_x0

if solver != 'pcg':
# PCG only works with operators with diagonal
xc = solv2 @ be2
solv2_x0 = solv2._options["x0"]
assert np.array_equal(xc.toarray(), solv2_x0.toarray())
assert xc is not solv2_x0


# Verify correctness of calculation: 2-norm of error
b = A @ x
b2 = A @ x2
bt = A.T @ xt
bh = A.H @ xh
if solver != 'pcg':
bc = A @ A @ xc


err = b - be
Expand All @@ -151,6 +162,10 @@ def test_solver_tridiagonal(n, p, dtype, solver, verbose=False):
errh = bh - beh
errh_norm = np.linalg.norm( errh.toarray() )

if solver != 'pcg':
errc = bc - be2
errc_norm = np.linalg.norm( errc.toarray() )

#---------------------------------------------------------------------------
# TERMINAL OUTPUT
#---------------------------------------------------------------------------
Expand Down Expand Up @@ -180,7 +195,7 @@ def test_solver_tridiagonal(n, p, dtype, solver, verbose=False):
assert err2_norm < tol
assert errt_norm < tol
assert errh_norm < tol

assert solver == 'pcg' or errc_norm < tol

# ===============================================================================
# SCRIPT FUNCTIONALITY
Expand Down

0 comments on commit 652b29f

Please sign in to comment.