Skip to content

Commit

Permalink
add Kronecker M1 solver to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jowezarek committed Oct 2, 2023
1 parent 792e019 commit b8068f7
Showing 1 changed file with 181 additions and 4 deletions.
185 changes: 181 additions & 4 deletions psydac/linalg/tests/test_kron_direct_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,26 @@
import pytest
import time
import numpy as np
from mpi4py import MPI
from psydac.ddm.cart import DomainDecomposition, CartDecomposition
from mpi4py import MPI


from scipy.sparse import csc_matrix, dia_matrix, kron
from scipy.sparse.linalg import splu
from psydac.linalg.stencil import StencilVectorSpace, StencilVector, StencilMatrix
from psydac.linalg.kron import KroneckerLinearSolver

from sympde.calculus import dot
from sympde.expr import BilinearForm, integral
from sympde.topology import Line
from sympde.topology import Cube
from sympde.topology import Derham
from sympde.topology import elements_of

from psydac.api.discretization import discretize
from psydac.ddm.cart import DomainDecomposition, CartDecomposition
from psydac.linalg.block import BlockLinearOperator
from psydac.linalg.direct_solvers import SparseSolver, BandedSolver
from psydac.linalg.kron import KroneckerLinearSolver
from psydac.linalg.solvers import inverse
from psydac.linalg.stencil import StencilVectorSpace, StencilVector, StencilMatrix

#===============================================================================
def compute_global_starts_ends(domain_decomposition, npts):
Expand Down Expand Up @@ -169,6 +182,74 @@ def compare_solve(seed, comm, npts, pads, periods, direct_solver, dtype=float, t
# compare for equality
assert np.allclose( X[localslice], X_glob[localslice], rtol=1e-8, atol=1e-8 )

def get_M1_block_kron_solver(V1, ncells, degree, periodic):
"""
Given a 3D DeRham sequenece (V0 = H(grad) --grad--> V1 = H(curl) --curl--> V2 = H(div) --div--> V3 = L2)
discreticed using ncells, degree and periodic,
domain = Cube('C', bounds1=(0, 1), bounds2=(0, 1), bounds3=(0, 1))
derham = Derham(domain)
domain_h = discretize(domain, ncells=ncells, periodic=periodic, comm=comm)
derham_h = discretize(derham, domain_h, degree=degree),
returns the inverse of the mass matrix M1 as a BlockLinearOperator consisting of three KroneckerLinearSolvers on the diagonal.
"""
# assert 3D
assert len(ncells) == 3
assert len(degree) == 3
assert len(periodic) == 3

# 1D domain to be discreticed using the respective values of ncells, degree, periodic
domain_1d = Line('L', bounds=(0,1))
derham_1d = Derham(domain_1d)

# storage for the 1D mass matrices
M0_matrices = []
M1_matrices = []

# assembly of the 1D mass matrices
for (n, p, P) in zip(ncells, degree, periodic):

domain_1d_h = discretize(domain_1d, ncells=[n], periodic=[P])
derham_1d_h = discretize(derham_1d, domain_1d_h, degree=[p])

u_1d_0, v_1d_0 = elements_of(derham_1d.V0, names='u_1d_0, v_1d_0')
u_1d_1, v_1d_1 = elements_of(derham_1d.V1, names='u_1d_1, v_1d_1')

a_1d_0 = BilinearForm((u_1d_0, v_1d_0), integral(domain_1d, u_1d_0 * v_1d_0))
a_1d_1 = BilinearForm((u_1d_1, v_1d_1), integral(domain_1d, u_1d_1 * v_1d_1))

a_1d_0_h = discretize(a_1d_0, domain_1d_h, (derham_1d_h.V0, derham_1d_h.V0))
a_1d_1_h = discretize(a_1d_1, domain_1d_h, (derham_1d_h.V1, derham_1d_h.V1))

M_1d_0 = a_1d_0_h.assemble()
M_1d_1 = a_1d_1_h.assemble()

M0_matrices.append(M_1d_0)
M1_matrices.append(M_1d_1)

V1_1 = V1[0]
V1_2 = V1[1]
V1_3 = V1[2]

B1_mat = [M1_matrices[0], M0_matrices[1], M0_matrices[2]]
B2_mat = [M0_matrices[0], M1_matrices[1], M0_matrices[2]]
B3_mat = [M0_matrices[0], M0_matrices[1], M1_matrices[2]]

B1_solvers = [matrix_to_bandsolver(Ai) for Ai in B1_mat]
B2_solvers = [matrix_to_bandsolver(Ai) for Ai in B2_mat]
B3_solvers = [matrix_to_bandsolver(Ai) for Ai in B3_mat]

B1_kron_inv = KroneckerLinearSolver(V1_1, V1_1, B1_solvers)
B2_kron_inv = KroneckerLinearSolver(V1_2, V1_2, B2_solvers)
B3_kron_inv = KroneckerLinearSolver(V1_3, V1_3, B3_solvers)

M1_block_kron_solver = BlockLinearOperator(V1, V1, ((B1_kron_inv, None, None),
(None, B2_kron_inv, None),
(None, None, B3_kron_inv)))

return M1_block_kron_solver

#===============================================================================
# tests of the direct solvers
@pytest.mark.parametrize( 'dtype', [float, complex] )
Expand Down Expand Up @@ -370,6 +451,102 @@ def test_kron_solver_nd_par(seed, dim, dtype):

npts_base = 4
compare_solve(seed, MPI.COMM_WORLD, [npts_base]*dim, [1]*dim, [False]*dim, matrix_to_sparse, dtype=dtype, transposed=False, verbose=False)

#===============================================================================

# test Kronecker solver of the M1 mass matrix of our 3D DeRham sequence, as described in the get_M1_block_kron_solver method

@pytest.mark.parametrize( 'ncells', [[8, 8, 8], [8, 16, 8]] )
@pytest.mark.parametrize( 'degree', [[2, 2, 2]] )
@pytest.mark.parametrize( 'periodic', [[True, True, True]] )
@pytest.mark.parallel
def test_3d_m1_solver(ncells, degree, periodic):

comm = MPI.COMM_WORLD
domain = Cube('C', bounds1=(0, 1), bounds2=(0, 1), bounds3=(0, 1))
derham = Derham(domain)
domain_h = discretize(domain, ncells=ncells, periodic=periodic, comm=comm)
derham_h = discretize(derham, domain_h, degree=degree)
V1 = derham_h.V1.vector_space
P0, P1, P2, P3 = derham_h.projectors()

# obtain an iterative M1 solver the usual way
u1, v1 = elements_of(derham.V1, names='u1, v1')
a1 = BilinearForm((u1, v1), integral(domain, dot(u1, v1)))
a1_h = discretize(a1, domain_h, (derham_h.V1, derham_h.V1))
M1 = a1_h.assemble()
tol = 1e-12
maxiter = 1000
M1_iterative_solver = inverse(M1, 'cg', tol = tol, maxiter=maxiter)

# obtain a direct M1 solver utilizing the Block-Kronecker structure of M1
M1_direct_solver = get_M1_block_kron_solver(V1, ncells, degree, periodic)

# obtain x and rhs = M1 @ x, both elements of derham_h.V1
def get_A_fun(n=1, m=1, A0=1e04):
"""Get the tuple A = (A1, A2, A3), where each entry is a function taking x,y,z as input."""

mu_tilde = np.sqrt(m**2 + n**2)

eta = lambda x, y, z: x**2 * (1-x)**2 * y**2 * (1-y)**2 * z**2 * (1-z)**2

u1 = lambda x, y, z: A0 * (n/mu_tilde) * np.sin(np.pi * m * x) * np.cos(np.pi * n * y)
u2 = lambda x, y, z: -A0 * (m/mu_tilde) * np.cos(np.pi * m * x) * np.sin(np.pi * n * y)
u3 = lambda x, y, z: A0 * np.sin(np.pi * m * x) * np.sin(np.pi * n * y)

A1 = lambda x, y, z: eta(x, y, z) * u1(x, y, z)
A2 = lambda x, y, z: eta(x, y, z) * u2(x, y, z)
A3 = lambda x, y, z: eta(x, y, z) * u3(x, y, z)

A = (A1, A2, A3)
return A
x = P1(get_A_fun()).coeffs
rhs = M1 @ x

# solve M1 @ x = rhs for x two ways
# pass -s to see timings
# on my local machine, executing
# mpirun -n 4 python -m pytest test_kron_direct_solver.py::test_3d_m1_solver -s
# I can report the following data:

### 4 processes, test case 1 (ncells=[8, 8, 8]):

# Solving for x using the iterative solver: 23.73982548713684 seconds
# Solving for x using the iterative solver: 23.820897102355957 seconds
# Solving for x using the iterative solver: 23.783425092697144 seconds
# Solving for x using the iterative solver: 23.71373987197876 seconds
# Solving for x using the direct solver: 0.3333120346069336 seconds
# Solving for x using the direct solver: 0.3369138240814209 seconds
# Solving for x using the direct solver: 0.33652329444885254 seconds
# Solving for x using the direct solver: 0.34088802337646484 seconds

###4 processes, test case 2 (ncells=[8, 16, 8]):
# Solving for x using the iterative solver: 82.10541296005249 seconds
# Solving for x using the iterative solver: 81.88263297080994 seconds
# Solving for x using the iterative solver: 82.07102465629578 seconds
# Solving for x using the iterative solver: 82.00282955169678 seconds
# Solving for x using the direct solver: 0.1675126552581787 seconds
# Solving for x using the direct solver: 0.17473626136779785 seconds
# Solving for x using the direct solver: 0.15992450714111328 seconds
# Solving for x using the direct solver: 0.17931437492370605 seconds

# Note that on consecutive solves, with only a slightly changing rhs and recycle=True, the iterative solver won't perform as bad anymore.

start = time.time()
x_iterative = M1_iterative_solver @ rhs
stop = time.time()
print(f"Solving for x using the iterative solver: {stop-start} seconds")

start = time.time()
x_direct = M1_direct_solver @ rhs
stop = time.time()
print(f"Solving for x using the direct solver: {stop-start} seconds")

# assert rhs_iterative is within the tolerance close to rhs, and so is rhs_direct
rhs_iterative = M1 @ x_iterative
rhs_direct = M1 @ x_direct
assert np.linalg.norm((rhs-rhs_iterative).toarray()) < tol
assert np.linalg.norm((rhs-rhs_direct).toarray()) < tol
#===============================================================================

if __name__ == '__main__':
Expand Down

0 comments on commit b8068f7

Please sign in to comment.