Skip to content

Commit

Permalink
Remove DirectSolver, remove transposed from solve() troughout, adapt …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
jowezarek committed Oct 20, 2023
1 parent cf31b8c commit af2992e
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 100 deletions.
9 changes: 0 additions & 9 deletions docs/source/modules/linalg/direct_solvers.rst
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
linalg.direct_solvers
=====================

* :ref:`DirectSolver <directsolver>`
* :ref:`BandedSolver <bandedsolver>`
* :ref:`SparseSolver <sparsesolver>`

.. inheritance-diagram:: psydac.linalg.direct_solvers

.. _directsolver:

DirectSolver
------------

.. autoclass:: psydac.linalg.direct_solvers.DirectSolver
:members:

.. _bandedsolver:

BandedSolver
Expand Down
14 changes: 7 additions & 7 deletions psydac/linalg/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,14 +1005,14 @@ def space(self):
pass

@abstractmethod
def transpose(self, conjugate=False):
"""
Transpose the LinearSolver.
If conjugate is True, return the Hermitian transpose.
"""
def transpose(self):
"""Return the transpose of the LinearSolver."""
pass

@abstractmethod
def solve(self, rhs, out=None, transposed=False):
def solve(self, rhs, out=None):
pass

@property
def T(self):
return self.transpose()
97 changes: 41 additions & 56 deletions psydac/linalg/direct_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,10 @@

from psydac.linalg.basic import LinearSolver

__all__ = ('DirectSolver', 'BandedSolver', 'SparseSolver')
__all__ = ('BandedSolver', 'SparseSolver')

#===============================================================================
class DirectSolver( LinearSolver ):
"""
Abstract class for direct linear solvers.
"""

#-------------------------------------
# Deferred methods
#-------------------------------------
@property
@abstractmethod
def space( self ):
pass

@abstractmethod
def transpose(self, conjugate=False):
"""
Transpose the DirectSolver.
If conjugate is True, return the Hermitian transpose.
"""
pass

@abstractmethod
def solve( self, rhs, out=None, transposed=False ):
pass

#===============================================================================
class BandedSolver ( DirectSolver ):
class BandedSolver(LinearSolver):
"""
Solve the equation Ax = b for x, assuming A is banded matrix.
Expand All @@ -56,10 +28,11 @@ class BandedSolver ( DirectSolver ):
Banded matrix.
"""
def __init__( self, u, l, bmat ):
def __init__(self, u, l, bmat, transposed=False):

self._u = u
self._l = l
self._transposed = transposed

# ... LU factorization
if bmat.dtype == np.float32:
Expand All @@ -75,7 +48,7 @@ def __init__( self, u, l, bmat ):
self._factor_function = zgbtrf
self._solver_function = zgbtrs
else:
msg = f'Cannot create a DirectSolver for bmat.dtype = {bmat.dtype}'
msg = f'Cannot create a BandedSolver for bmat.dtype = {bmat.dtype}'
raise NotImplementedError(msg)

self._bmat, self._ipiv, self._finfo = self._factor_function(bmat, l, u)
Expand All @@ -86,25 +59,40 @@ def __init__( self, u, l, bmat ):
self._dtype = bmat.dtype

@property
def finfo( self ):
def finfo(self):
return self._finfo

@property
def sinfo( self ):
def sinfo(self):
return self._sinfo

#--------------------------------------
# Abstract interface
#--------------------------------------
@property
def space( self ):
def space(self):
return self._space

def transpose(self, conjugate=False):
raise NotImplementedError('transpose() is not implemented for BandedSolvers')

def transpose(self):
cls = type(self)
obj = super().__new__(cls)

obj._u = self._l
obj._l = self._u
obj._bmat = self._bmat
obj._ipiv = self._ipiv
obj._finfo = self._finfo
obj._factor_function = self._factor_function
obj._solver_function = self._solver_function
obj._sinfo = None
obj._space = self._space
obj._dtype = self._dtype
obj._transposed = not self._transposed

return obj

#...
def solve( self, rhs, out=None, transposed=False ):
def solve(self, rhs, out=None):
"""
Solves for the given right-hand side.
Expand All @@ -118,12 +106,11 @@ def solve( self, rhs, out=None, transposed=False ):
out : ndarray | NoneType
Output vector. If given, it has to have the same shape and datatype as rhs.
transposed : bool
If and only if set to true, we solve against the transposed matrix. (supported by the underlying solver)
"""
assert rhs.T.shape[0] == self._bmat.shape[1]

transposed = self._transposed

if out is None:
preout, self._sinfo = self._solver_function(self._bmat, self._l, self._u, rhs.T, self._ipiv,
trans=transposed)
Expand All @@ -146,7 +133,7 @@ def solve( self, rhs, out=None, transposed=False ):
return out

#===============================================================================
class SparseSolver ( DirectSolver ):
class SparseSolver (LinearSolver):
"""
Solve the equation Ax = b for x, assuming A is scipy sparse matrix.
Expand All @@ -156,25 +143,26 @@ class SparseSolver ( DirectSolver ):
Generic sparse matrix.
"""
def __init__( self, spmat ):
def __init__(self, spmat):

assert isinstance( spmat, spmatrix )
assert isinstance(spmat, spmatrix)

self._space = np.ndarray
self._splu = splu( spmat.tocsc() )
self._spmat = spmat
self._splu = splu(spmat.tocsc())

#--------------------------------------
# Abstract interface
#--------------------------------------
@property
def space( self ):
def space(self):
return self._space
def transpose(self, conjugate=False):
raise NotImplementedError('transpose() is not implemented for SparseSolvers')

def transpose(self):
return SparseSolver(self._spmat.transpose())

#...
def solve( self, rhs, out=None, transposed=False ):
def solve(self, rhs, out=None):
"""
Solves for the given right-hand side.
Expand All @@ -188,21 +176,18 @@ def solve( self, rhs, out=None, transposed=False ):
out : ndarray | NoneType
Output vector. If given, it has to have the same shape and datatype as rhs.
transposed : bool
If and only if set to true, we solve against the transposed matrix. (supported by the underlying solver)
"""

assert rhs.T.shape[0] == self._splu.shape[1]

if out is None:
out = self._splu.solve( rhs.T, trans='T' if transposed else 'N' ).T
out = self._splu.solve(rhs.T).T

else:
assert out.shape == rhs.shape
assert out.dtype == rhs.dtype

# currently no in-place solve exposed
out[:] = self._splu.solve( rhs.T, trans='T' if transposed else 'N' ).T
out[:] = self._splu.solve(rhs.T).T

return out
4 changes: 2 additions & 2 deletions psydac/linalg/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def __init__(self, function):
def space(self):
return np.ndarray

def transpose(self, conjugate=False):
def transpose(self):
raise NotImplementedError('transpose() is not implemented for OneDimSolvers')

def solve(self, rhs, out=None, transposed=False):
def solve(self, rhs, out=None):
if out is None:
out = np.empty_like(rhs)

Expand Down
34 changes: 14 additions & 20 deletions psydac/linalg/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def tosparse(self):
def transpose(self, conjugate=False):
new_domain = self._codomain
new_codomain = self._domain
new_solvers = [solver.transpose(conjugate=conjugate) for solver in self._solvers]
new_solvers = [solver.transpose() for solver in self._solvers]
return KroneckerLinearSolver(new_domain, new_codomain, new_solvers)

def dot(self, v, out=None):
Expand All @@ -545,7 +545,7 @@ def solvers(self):
"""
return tuple(self._solvers)

def solve(self, rhs, out=None, transposed=False):
def solve(self, rhs, out=None):
"""
Solves Ax=b where A is a Kronecker product matrix (and represented as such),
and b is a suitable vector.
Expand All @@ -564,12 +564,12 @@ def solve(self, rhs, out=None, transposed=False):
outslice = out[self._slice]

# call the actual kernel
self._solve_nd(inslice, outslice, transposed)
self._solve_nd(inslice, outslice)

out.update_ghost_regions()
return out

def _solve_nd(self, inslice, outslice, transposed):
def _solve_nd(self, inslice, outslice):
"""
The internal solve loop. Can handle arbitrary dimensions.
"""
Expand All @@ -582,14 +582,14 @@ def _solve_nd(self, inslice, outslice, transposed):
# internal passes
for i in range(self._ndim - 1):
# solve direction
self._solver_passes[i].solve_pass(temp1, temp2, transposed)
self._solver_passes[i].solve_pass(temp1, temp2)

# reorder and swap
self._reorder_temp_to_temp(temp1, temp2, i)
temp1, temp2 = temp2, temp1

# last pass
self._solver_passes[-1].solve_pass(temp1, temp2, transposed)
self._solver_passes[-1].solve_pass(temp1, temp2)

# copy to output
self._reorder_temp_to_outslice(temp1, outslice)
Expand Down Expand Up @@ -634,7 +634,7 @@ class KroneckerSolverSerialPass:
Parameters
----------
solver : DirectSolver
solver : BandedSolver or SparseSolver
The internally used solver class.
nglobal : int
Expand All @@ -659,7 +659,7 @@ def required_memory(self):
"""
return self._datasize

def solve_pass(self, workmem, tempmem, transposed):
def solve_pass(self, workmem, tempmem):
"""
Solves the data available in workmem, assuming that all data is available locally.
Expand All @@ -672,16 +672,13 @@ def solve_pass(self, workmem, tempmem, transposed):
tempmem : ndarray
Ignored, it exists for compatibility with the parallel solver.
transposed : bool
True, if and only if we want to solve against the transposed matrix instead.
"""
# reshape necessary memory in column-major
view = workmem[:self._datasize]
view.shape = (self._numrhs,self._dimrhs)

# call solver in in-place mode
self._solver.solve(view, out=view, transposed=transposed)
self._solver.solve(view, out=view)

class KroneckerSolverParallelPass:
"""
Expand All @@ -701,7 +698,7 @@ class KroneckerSolverParallelPass:
Parameters
----------
solver : DirectSolver
solver : BandedSolver or SparseSolver
The internally used solver class.
mpi_type : MPI type
Expand Down Expand Up @@ -846,7 +843,7 @@ def _contiguous_to_blocked(self, blocked, contiguous):
contiguouspart.shape = (self._mlocal,end-start)
contiguouspart[:] = blocked_view[:,start:end]

def solve_pass(self, workmem, tempmem, transposed):
def solve_pass(self, workmem, tempmem):
"""
Solves the data available in workmem in a distributed manner, using MPI_Alltoallv.
Expand All @@ -859,9 +856,6 @@ def solve_pass(self, workmem, tempmem, transposed):
tempmem : ndarray
Temporary array of the same minimum size as workmem.
transposed : bool
True, if and only if we want to solve against the transposed matrix instead.
"""
# preparation
sourceargs = [workmem[:self._localsize], self._source_transfer, self._mpi_type]
Expand All @@ -874,7 +868,7 @@ def solve_pass(self, workmem, tempmem, transposed):
self._blocked_to_contiguous(workmem, tempmem)

# actual solve (source contains the data)
self._serialsolver.solve_pass(workmem, tempmem, transposed)
self._serialsolver.solve_pass(workmem, tempmem)

# ordered stripes -> blocked stripes
self._contiguous_to_blocked(workmem, tempmem)
Expand All @@ -883,7 +877,7 @@ def solve_pass(self, workmem, tempmem, transposed):
self._comm.Alltoallv(targetargs, sourceargs)

#==============================================================================
def kronecker_solve(solvers, rhs, out=None, transposed=False):
def kronecker_solve(solvers, rhs, out=None):
"""
Solve linear system Ax=b with A=kron( A_n, A_{n-1}, ..., A_2, A_1 ), given
$n$ separate linear solvers $L_n$ for the 1D problems $A_n x_n = b_n$:
Expand Down Expand Up @@ -914,4 +908,4 @@ def kronecker_solve(solvers, rhs, out=None, transposed=False):
out = StencilVector(rhs.space)

kronsolver = KroneckerLinearSolver(rhs.space, rhs.space, solvers)
return kronsolver.solve(rhs, out=out, transposed=transposed)
return kronsolver.solve(rhs, out=out)
Loading

0 comments on commit af2992e

Please sign in to comment.