diff --git a/psydac/linalg/tests/test_kron_direct_solver.py b/psydac/linalg/tests/test_kron_direct_solver.py
index c578d8c93..6762ac76d 100644
--- a/psydac/linalg/tests/test_kron_direct_solver.py
+++ b/psydac/linalg/tests/test_kron_direct_solver.py
@@ -3,13 +3,26 @@ import pytest
 import time
 import numpy as np
-from mpi4py import MPI
-from psydac.ddm.cart import DomainDecomposition, CartDecomposition
+from mpi4py import MPI
+
 
 from scipy.sparse import csc_matrix, dia_matrix, kron
 from scipy.sparse.linalg import splu
-from psydac.linalg.stencil import StencilVectorSpace, StencilVector, StencilMatrix
-from psydac.linalg.kron import KroneckerLinearSolver
+
+from sympde.calculus import dot
+from sympde.expr import BilinearForm, integral
+from sympde.topology import Line
+from sympde.topology import Cube
+from sympde.topology import Derham
+from sympde.topology import elements_of
+
+from psydac.api.discretization import discretize
+from psydac.ddm.cart import DomainDecomposition, CartDecomposition
+from psydac.linalg.block import BlockLinearOperator
 from psydac.linalg.direct_solvers import SparseSolver, BandedSolver
+from psydac.linalg.kron import KroneckerLinearSolver
+from psydac.linalg.solvers import inverse
+from psydac.linalg.stencil import StencilVectorSpace, StencilVector, StencilMatrix
 
 #===============================================================================
 def compute_global_starts_ends(domain_decomposition, npts):
@@ -169,6 +182,74 @@ def compare_solve(seed, comm, npts, pads, periods, direct_solver, dtype=float, t
     # compare for equality
     assert np.allclose( X[localslice], X_glob[localslice], rtol=1e-8, atol=1e-8 )
 
+def get_M1_block_kron_solver(V1, ncells, degree, periodic):
+    """
+    Given a 3D de Rham sequence (V0 = H(grad) --grad--> V1 = H(curl) --curl--> V2 = H(div) --div--> V3 = L2),
+    discretized using ncells, degree and periodic,
+
+        domain   = Cube('C', bounds1=(0, 1), bounds2=(0, 1), bounds3=(0, 1))
+        derham   = Derham(domain)
+        domain_h = discretize(domain, ncells=ncells, periodic=periodic, comm=comm)
+        derham_h = discretize(derham, domain_h, degree=degree),
+
+    returns the inverse of the mass matrix M1 as a BlockLinearOperator consisting of three KroneckerLinearSolvers on the diagonal.
+ """ + # assert 3D + assert len(ncells) == 3 + assert len(degree) == 3 + assert len(periodic) == 3 + + # 1D domain to be discreticed using the respective values of ncells, degree, periodic + domain_1d = Line('L', bounds=(0,1)) + derham_1d = Derham(domain_1d) + + # storage for the 1D mass matrices + M0_matrices = [] + M1_matrices = [] + + # assembly of the 1D mass matrices + for (n, p, P) in zip(ncells, degree, periodic): + + domain_1d_h = discretize(domain_1d, ncells=[n], periodic=[P]) + derham_1d_h = discretize(derham_1d, domain_1d_h, degree=[p]) + + u_1d_0, v_1d_0 = elements_of(derham_1d.V0, names='u_1d_0, v_1d_0') + u_1d_1, v_1d_1 = elements_of(derham_1d.V1, names='u_1d_1, v_1d_1') + + a_1d_0 = BilinearForm((u_1d_0, v_1d_0), integral(domain_1d, u_1d_0 * v_1d_0)) + a_1d_1 = BilinearForm((u_1d_1, v_1d_1), integral(domain_1d, u_1d_1 * v_1d_1)) + + a_1d_0_h = discretize(a_1d_0, domain_1d_h, (derham_1d_h.V0, derham_1d_h.V0)) + a_1d_1_h = discretize(a_1d_1, domain_1d_h, (derham_1d_h.V1, derham_1d_h.V1)) + + M_1d_0 = a_1d_0_h.assemble() + M_1d_1 = a_1d_1_h.assemble() + + M0_matrices.append(M_1d_0) + M1_matrices.append(M_1d_1) + + V1_1 = V1[0] + V1_2 = V1[1] + V1_3 = V1[2] + + B1_mat = [M1_matrices[0], M0_matrices[1], M0_matrices[2]] + B2_mat = [M0_matrices[0], M1_matrices[1], M0_matrices[2]] + B3_mat = [M0_matrices[0], M0_matrices[1], M1_matrices[2]] + + B1_solvers = [matrix_to_bandsolver(Ai) for Ai in B1_mat] + B2_solvers = [matrix_to_bandsolver(Ai) for Ai in B2_mat] + B3_solvers = [matrix_to_bandsolver(Ai) for Ai in B3_mat] + + B1_kron_inv = KroneckerLinearSolver(V1_1, V1_1, B1_solvers) + B2_kron_inv = KroneckerLinearSolver(V1_2, V1_2, B2_solvers) + B3_kron_inv = KroneckerLinearSolver(V1_3, V1_3, B3_solvers) + + M1_block_kron_solver = BlockLinearOperator(V1, V1, ((B1_kron_inv, None, None), + (None, B2_kron_inv, None), + (None, None, B3_kron_inv))) + + return M1_block_kron_solver + #=============================================================================== # tests of the direct solvers @pytest.mark.parametrize( 'dtype', [float, complex] ) @@ -370,6 +451,102 @@ def test_kron_solver_nd_par(seed, dim, dtype): npts_base = 4 compare_solve(seed, MPI.COMM_WORLD, [npts_base]*dim, [1]*dim, [False]*dim, matrix_to_sparse, dtype=dtype, transposed=False, verbose=False) + +#=============================================================================== + +# test Kronecker solver of the M1 mass matrix of our 3D DeRham sequence, as described in the get_M1_block_kron_solver method + +@pytest.mark.parametrize( 'ncells', [[8, 8, 8], [8, 16, 8]] ) +@pytest.mark.parametrize( 'degree', [[2, 2, 2]] ) +@pytest.mark.parametrize( 'periodic', [[True, True, True]] ) +@pytest.mark.parallel +def test_3d_m1_solver(ncells, degree, periodic): + + comm = MPI.COMM_WORLD + domain = Cube('C', bounds1=(0, 1), bounds2=(0, 1), bounds3=(0, 1)) + derham = Derham(domain) + domain_h = discretize(domain, ncells=ncells, periodic=periodic, comm=comm) + derham_h = discretize(derham, domain_h, degree=degree) + V1 = derham_h.V1.vector_space + P0, P1, P2, P3 = derham_h.projectors() + + # obtain an iterative M1 solver the usual way + u1, v1 = elements_of(derham.V1, names='u1, v1') + a1 = BilinearForm((u1, v1), integral(domain, dot(u1, v1))) + a1_h = discretize(a1, domain_h, (derham_h.V1, derham_h.V1)) + M1 = a1_h.assemble() + tol = 1e-12 + maxiter = 1000 + M1_iterative_solver = inverse(M1, 'cg', tol = tol, maxiter=maxiter) + + # obtain a direct M1 solver utilizing the Block-Kronecker structure of M1 + M1_direct_solver = 
+    M1_direct_solver = get_M1_block_kron_solver(V1, ncells, degree, periodic)
+
+    # obtain x and rhs = M1 @ x, both elements of derham_h.V1
+    def get_A_fun(n=1, m=1, A0=1e04):
+        """Return the tuple A = (A1, A2, A3), where each entry is a function taking x, y, z as input."""
+
+        mu_tilde = np.sqrt(m**2 + n**2)
+
+        eta = lambda x, y, z: x**2 * (1-x)**2 * y**2 * (1-y)**2 * z**2 * (1-z)**2
+
+        u1 = lambda x, y, z:  A0 * (n/mu_tilde) * np.sin(np.pi * m * x) * np.cos(np.pi * n * y)
+        u2 = lambda x, y, z: -A0 * (m/mu_tilde) * np.cos(np.pi * m * x) * np.sin(np.pi * n * y)
+        u3 = lambda x, y, z:  A0 * np.sin(np.pi * m * x) * np.sin(np.pi * n * y)
+
+        A1 = lambda x, y, z: eta(x, y, z) * u1(x, y, z)
+        A2 = lambda x, y, z: eta(x, y, z) * u2(x, y, z)
+        A3 = lambda x, y, z: eta(x, y, z) * u3(x, y, z)
+
+        A = (A1, A2, A3)
+        return A
+
+    x = P1(get_A_fun()).coeffs
+    rhs = M1 @ x
+
+    # solve M1 @ x = rhs for x in two ways
+    # pass -s to see timings
+    # on my local machine, executing
+    #     mpirun -n 4 python -m pytest test_kron_direct_solver.py::test_3d_m1_solver -s
+    # I can report the following data:
+
+    ### 4 processes, test case 1 (ncells=[8, 8, 8]):
+
+    # Solving for x using the iterative solver: 23.73982548713684 seconds
+    # Solving for x using the iterative solver: 23.820897102355957 seconds
+    # Solving for x using the iterative solver: 23.783425092697144 seconds
+    # Solving for x using the iterative solver: 23.71373987197876 seconds
+    # Solving for x using the direct solver: 0.3333120346069336 seconds
+    # Solving for x using the direct solver: 0.3369138240814209 seconds
+    # Solving for x using the direct solver: 0.33652329444885254 seconds
+    # Solving for x using the direct solver: 0.34088802337646484 seconds
+
+    ### 4 processes, test case 2 (ncells=[8, 16, 8]):
+
+    # Solving for x using the iterative solver: 82.10541296005249 seconds
+    # Solving for x using the iterative solver: 81.88263297080994 seconds
+    # Solving for x using the iterative solver: 82.07102465629578 seconds
+    # Solving for x using the iterative solver: 82.00282955169678 seconds
+    # Solving for x using the direct solver: 0.1675126552581787 seconds
+    # Solving for x using the direct solver: 0.17473626136779785 seconds
+    # Solving for x using the direct solver: 0.15992450714111328 seconds
+    # Solving for x using the direct solver: 0.17931437492370605 seconds
+
+    # Note that on consecutive solves, with a rhs that changes only slightly and recycle=True, the iterative solver won't perform as badly anymore.
+
+    start = time.time()
+    x_iterative = M1_iterative_solver @ rhs
+    stop = time.time()
+    print(f"Solving for x using the iterative solver: {stop-start} seconds")
+
+    start = time.time()
+    x_direct = M1_direct_solver @ rhs
+    stop = time.time()
+    print(f"Solving for x using the direct solver: {stop-start} seconds")
+
+    # assert that rhs_iterative is close to rhs within the tolerance, and so is rhs_direct
+    rhs_iterative = M1 @ x_iterative
+    rhs_direct = M1 @ x_direct
+    assert np.linalg.norm((rhs - rhs_iterative).toarray()) < tol
+    assert np.linalg.norm((rhs - rhs_direct).toarray()) < tol
 #===============================================================================
 
 if __name__ == '__main__':
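
For reference, here is a minimal standalone sketch (not part of the patch) of the structure the direct solver exploits: each diagonal block of the H(curl) mass matrix M1 is a Kronecker product of 1D mass matrices (the first block is M1_x x M0_y x M0_z, matching B1_mat above), so inverting it only requires factorizations of the small 1D matrices. The 1D matrices below are random SPD stand-ins for the actual spline mass matrices, and the per-direction solve is checked against a sparse LU factorization of the fully assembled block.

import numpy as np
from scipy.sparse import csc_matrix, kron
from scipy.sparse.linalg import splu

rng = np.random.default_rng(0)

def random_spd(n):
    # dense SPD stand-in for a banded 1D spline mass matrix
    A = rng.random((n, n))
    return A @ A.T + n * np.eye(n)

n1, n2, n3 = 4, 5, 6
M0 = [random_spd(n) for n in (n1, n2, n3)]   # stand-ins for the 1D V0 mass matrices
M1 = [random_spd(n) for n in (n1, n2, n3)]   # stand-ins for the 1D V1 mass matrices

# first diagonal block of the M1 mass matrix, cf. B1_mat = [M1_x, M0_y, M0_z] above
B1 = kron(kron(csc_matrix(M1[0]), csc_matrix(M0[1])), csc_matrix(M0[2])).tocsc()

b = rng.random(n1 * n2 * n3)

# reference: factorize and solve the fully assembled 3D block
x_full = splu(B1).solve(b)

# Kronecker solve: only 1D solves, one tensor direction at a time
# (this is what KroneckerLinearSolver does, with banded/sparse 1D direct solvers)
X = b.reshape(n1, n2, n3)   # C-order reshape matches scipy's kron ordering
X = np.linalg.solve(M1[0], X.reshape(n1, -1)).reshape(n1, n2, n3)
X = np.linalg.solve(M0[1], X.transpose(1, 0, 2).reshape(n2, -1)).reshape(n2, n1, n3).transpose(1, 0, 2)
X = np.linalg.solve(M0[2], X.transpose(2, 0, 1).reshape(n3, -1)).reshape(n3, n1, n2).transpose(1, 2, 0)

assert np.allclose(X.ravel(), x_full)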